
train

This module provides training utilities for 4D-VarNet models using PyTorch Lightning.

Functions:

| Name | Description |
|---|---|
| `base_training` | Perform basic training and testing of a model with a single datamodule. |
| `multi_dm_training` | Perform training and testing with support for multiple datamodules. |

base_training(trainer, dm, lit_mod, ckpt=None)

Train a model on a single datamodule, then test the best checkpoint (`ckpt_path='best'`) on the same datamodule.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `trainer` | `Trainer` | The PyTorch Lightning trainer instance. | *required* |
| `dm` | `LightningDataModule` | The datamodule for training and testing. | *required* |
| `lit_mod` | `LightningModule` | The Lightning module to train. | *required* |
| `ckpt` | `str` | Path to a checkpoint to resume training from. | `None` |

Returns:

| Type | Description |
|---|---|
| `None` | |

Source code in ocean4dvarnet/train.py
```python
def base_training(trainer, dm, lit_mod, ckpt=None):
    """
    Perform basic training and testing of a model with a single datamodule.

    Args:
        trainer (pl.Trainer): The PyTorch Lightning trainer instance.
        dm (pl.LightningDataModule): The datamodule for training and testing.
        lit_mod (pl.LightningModule): The Lightning module to train.
        ckpt (str, optional): Path to a checkpoint to resume training from.

    Returns:
        None
    """
    # Report where the logger writes, if a logger is configured.
    if trainer.logger is not None:
        print()
        print("Logdir:", trainer.logger.log_dir)
        print()

    # Train, optionally resuming from `ckpt`.
    trainer.fit(lit_mod, datamodule=dm, ckpt_path=ckpt)
    # Test the best checkpoint tracked by the trainer, on the same datamodule.
    trainer.test(lit_mod, datamodule=dm, ckpt_path='best')
```
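`base_training` reads `trainer.logger.log_dir` and calls `trainer.fit` and `trainer.test`. Code that wraps it can therefore be exercised without a GPU or real data by passing in a duck-typed stand-in. The sketch below is one such test double, offered as an assumption rather than part of the package. `RecordingTrainer` and `StubLogger` are hypothetical names and are not part of ocean4dvarnet or PyTorch Lightning. They reproduce only the trainer attributes this module touches. That includes `checkpoint_callback.best_model_path` and `callbacks`, which `multi_dm_training` also reads.

```python
from types import SimpleNamespace


class StubLogger:
    """Hypothetical logger stand-in: the training helpers only read ``log_dir``."""

    def __init__(self, log_dir="logs/version_0"):
        self.log_dir = log_dir


class RecordingTrainer:
    """Hypothetical duck-typed Trainer that records calls instead of training."""

    def __init__(self, logger=None, best_model_path="best.ckpt"):
        self.logger = logger
        # multi_dm_training replaces this list with [] before testing.
        self.callbacks = ["checkpoint"]
        # multi_dm_training reads trainer.checkpoint_callback.best_model_path.
        self.checkpoint_callback = SimpleNamespace(best_model_path=best_model_path)
        self.calls = []

    def fit(self, model, datamodule=None, ckpt_path=None):
        self.calls.append(("fit", model, datamodule, ckpt_path))

    def test(self, model, datamodule=None, ckpt_path=None):
        self.calls.append(("test", model, datamodule, ckpt_path))


# Replay the two calls base_training makes, in the same order and with the same arguments.
trainer = RecordingTrainer(logger=StubLogger())
trainer.fit("model", datamodule="dm", ckpt_path=None)
trainer.test("model", datamodule="dm", ckpt_path="best")
```

With ocean4dvarnet installed, passing a fresh `RecordingTrainer(logger=StubLogger())` to `base_training` should record the same two entries as the replay: a `fit` entry with the given `ckpt`, then a `test` entry with `ckpt_path='best'`. To use the stub with `multi_dm_training`, two more things are needed. The stand-in `lit_mod` must accept attribute assignment, and the datamodule must provide `norm_stats()`.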

multi_dm_training(trainer, dm, lit_mod, test_dm=None, test_fn=None, ckpt=None)

Perform training and testing with support for multiple datamodules.

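This function trains the model on `dm`. Testing runs only if `test_fn` is provided. In that case, the function uses `test_dm` (or `dm` when `test_dm` is `None`) and sets the module's normalization statistics from that datamodule. It then clears the trainer's callbacks, tests the best checkpoint, and prints `test_fn(lit_mod).to_markdown()`. The trainer must have a checkpoint callback, because the best checkpoint path is read from `trainer.checkpoint_callback.best_model_path`. If `test_fn` is `None`, the model is trained but not tested.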

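Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `trainer` | `Trainer` | The PyTorch Lightning trainer instance. | *required* |
| `dm` | `LightningDataModule` | The datamodule for training. | *required* |
| `lit_mod` | `LightningModule` | The Lightning module to train. | *required* |
| `test_dm` | `LightningDataModule` | The datamodule for testing. Defaults to `dm`. Only used when `test_fn` is provided. | `None` |
| `test_fn` | `callable` | Function called as `test_fn(lit_mod)` after testing. Its return value must provide a `to_markdown()` method (for example, a pandas `DataFrame`). If `None`, the test step is skipped. | `None` |
| `ckpt` | `str` | Path to a checkpoint to resume training from. | `None` |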

Returns:

| Type | Description |
|---|---|
| `None` | |

Source code in ocean4dvarnet/train.py
```python
def multi_dm_training(
    trainer, dm, lit_mod, test_dm=None, test_fn=None, ckpt=None
):
    """
    Perform training and testing with support for multiple datamodules.

    This function trains the model using the provided datamodule and optionally tests it
    on a separate test datamodule. It also supports custom test functions for evaluation.

    Args:
        trainer (pl.Trainer): The PyTorch Lightning trainer instance.
        dm (pl.LightningDataModule): The datamodule for training.
        lit_mod (pl.LightningModule): The Lightning module to train.
        test_dm (pl.LightningDataModule, optional): The datamodule for testing. Defaults to `dm`.
        test_fn (callable, optional): A custom function to evaluate the model after testing.
        ckpt (str, optional): Path to a checkpoint to resume training from.

    Returns:
        None
    """
    # Report where the logger writes, if a logger is configured.
    if trainer.logger is not None:
        print()
        print("Logdir:", trainer.logger.log_dir)
        print()

    # Train on `dm`, optionally resuming from `ckpt`.
    trainer.fit(lit_mod, datamodule=dm, ckpt_path=ckpt)

    # Testing only runs when a test_fn is provided.
    if test_fn is not None:
        if test_dm is None:
            test_dm = dm
        # Use the test datamodule's normalization statistics during testing.
        lit_mod._norm_stats = test_dm.norm_stats()

        # Read the best checkpoint path before the callbacks (including the
        # checkpoint callback) are removed from the trainer.
        best_ckpt_path = trainer.checkpoint_callback.best_model_path
        trainer.callbacks = []
        trainer.test(lit_mod, datamodule=test_dm, ckpt_path=best_ckpt_path)

        # Print the custom evaluation returned by test_fn.
        print("\nBest ckpt score:")
        print(test_fn(lit_mod).to_markdown())
        print("\n###############")
```
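`multi_dm_training` calls `test_fn(lit_mod)` after `trainer.test` and prints the result's `to_markdown()` output, so `test_fn` must return an object with that method. A pandas `DataFrame` qualifies, although `DataFrame.to_markdown` needs the optional `tabulate` package. The sketch below avoids that dependency with a hypothetical `MetricsTable` class. For illustration only, it also assumes that the Lightning module stores its test metrics in a dict attribute named `test_metrics`. That attribute is not defined anywhere in this module; substitute however your own module exposes its results.

```python
from types import SimpleNamespace


class MetricsTable:
    """Hypothetical minimal stand-in for a DataFrame: only ``to_markdown()`` is needed."""

    def __init__(self, rows):
        self.rows = rows  # list of (metric_name, value) pairs

    def to_markdown(self):
        # Render a two-column Markdown table, one row per metric.
        lines = ["| metric | value |", "|---|---|"]
        lines += [f"| {name} | {value:.4f} |" for name, value in self.rows]
        return "\n".join(lines)


def example_test_fn(lit_mod):
    """Hypothetical test_fn: tabulates a ``test_metrics`` dict assumed to live on lit_mod."""
    metrics = getattr(lit_mod, "test_metrics", {})
    return MetricsTable(sorted(metrics.items()))


lit_mod = SimpleNamespace(test_metrics={"rmse": 0.1234, "mae": 0.05})
print(example_test_fn(lit_mod).to_markdown())
```

To use it, call `multi_dm_training(trainer, dm, lit_mod, test_fn=example_test_fn)`. Leaving `test_fn` as `None` trains the model without running the test step.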