train
This module provides training utilities for 4D-VarNet models using PyTorch Lightning.
Functions:

| Name | Description |
|---|---|
| `base_training` | Perform basic training and testing of a model with a single datamodule. |
| `multi_dm_training` | Perform training and testing with support for multiple datamodules. |
base_training(trainer, dm, lit_mod, ckpt=None)
Perform basic training and testing of a model with a single datamodule.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `trainer` | `Trainer` | The PyTorch Lightning trainer instance. | required |
| `dm` | `LightningDataModule` | The datamodule for training and testing. | required |
| `lit_mod` | `LightningModule` | The Lightning module to train. | required |
| `ckpt` | `str` | Path to a checkpoint to resume training from. | `None` |
Returns:

| Type | Description |
|---|---|
| `None` | |
Source code in ocean4dvarnet/train.py
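This page reproduces only the signature and docstring of `base_training`; the function body is not shown. As a rough illustration of the documented behavior ("basic training and testing ... with a single datamodule"), here is a minimal sketch, assuming the function calls Lightning's `Trainer.fit` and then `Trainer.test` on the same datamodule and forwards `ckpt` as `ckpt_path`. The names `base_training_sketch` and `RecordingTrainer` are hypothetical; the stub trainer only records calls so the control flow can run without PyTorch Lightning installed.

```python
from typing import Any, List, Optional, Tuple


def base_training_sketch(trainer: Any, dm: Any, lit_mod: Any, ckpt: Optional[str] = None) -> None:
    """Hypothetical fit-then-test workflow; not the ocean4dvarnet source."""
    # Train on the datamodule. ckpt_path=None starts from the module's current
    # weights; a checkpoint path resumes training from that checkpoint.
    trainer.fit(lit_mod, datamodule=dm, ckpt_path=ckpt)
    # Evaluate on the same datamodule's test split. Whether the real function
    # reloads a "best" checkpoint before testing is not stated on this page.
    trainer.test(lit_mod, datamodule=dm)


class RecordingTrainer:
    """Hypothetical stand-in for a Lightning Trainer that only records calls."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, Any, Any, Optional[str]]] = []

    def fit(self, model: Any, datamodule: Any = None, ckpt_path: Optional[str] = None) -> None:
        self.calls.append(("fit", model, datamodule, ckpt_path))

    def test(self, model: Any, datamodule: Any = None, ckpt_path: Optional[str] = None) -> None:
        self.calls.append(("test", model, datamodule, ckpt_path))


trainer = RecordingTrainer()
base_training_sketch(trainer, dm="my_dm", lit_mod="my_module", ckpt="last.ckpt")
print(trainer.calls)
# [('fit', 'my_module', 'my_dm', 'last.ckpt'), ('test', 'my_module', 'my_dm', None)]
```

With real objects the call shape is the same: pass a `Trainer`, a `LightningDataModule` and a `LightningModule` to `base_training`, plus an optional checkpoint path to resume from.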
multi_dm_training(trainer, dm, lit_mod, test_dm=None, test_fn=None, ckpt=None)
Perform training and testing with support for multiple datamodules.
This function trains the model using the provided datamodule and optionally tests it on a separate test datamodule. It also supports custom test functions for evaluation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `trainer` | `Trainer` | The PyTorch Lightning trainer instance. | required |
| `dm` | `LightningDataModule` | The datamodule for training. | required |
| `lit_mod` | `LightningModule` | The Lightning module to train. | required |
| `test_dm` | `LightningDataModule` | The datamodule for testing. Defaults to `None`. | `None` |
| `test_fn` | `callable` | A custom function to evaluate the model after testing. | `None` |
| `ckpt` | `str` | Path to a checkpoint to resume training from. | `None` |
Returns:

| Type | Description |
|---|---|
| `None` | |
Source code in ocean4dvarnet/train.py
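The docstring says `multi_dm_training` trains on `dm`, optionally tests on a separate `test_dm`, and can run a custom `test_fn` afterwards, but the body is not reproduced here. The sketch below illustrates that flow under two loudly labeled assumptions: when `test_dm` is `None` it tests on `dm` (the real function might instead skip testing), and `test_fn` is called with the trained module as its only argument (its actual expected signature is not documented on this page). `multi_dm_training_sketch` and `RecordingTrainer` are hypothetical names; the stub trainer records calls in place of a real Lightning `Trainer`.

```python
from typing import Any, Callable, List, Optional, Tuple


def multi_dm_training_sketch(
    trainer: Any,
    dm: Any,
    lit_mod: Any,
    test_dm: Optional[Any] = None,
    test_fn: Optional[Callable[[Any], Any]] = None,
    ckpt: Optional[str] = None,
) -> None:
    """Hypothetical train/test workflow with a separate test datamodule."""
    # Train on the training datamodule, optionally resuming from `ckpt`.
    trainer.fit(lit_mod, datamodule=dm, ckpt_path=ckpt)
    # Assumption: fall back to the training datamodule when no test_dm is given.
    eval_dm = test_dm if test_dm is not None else dm
    trainer.test(lit_mod, datamodule=eval_dm)
    # Assumption: the custom evaluation callable receives the trained module.
    if test_fn is not None:
        test_fn(lit_mod)


class RecordingTrainer:
    """Hypothetical stand-in for a Lightning Trainer that only records calls."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, Any, Any, Optional[str]]] = []

    def fit(self, model: Any, datamodule: Any = None, ckpt_path: Optional[str] = None) -> None:
        self.calls.append(("fit", model, datamodule, ckpt_path))

    def test(self, model: Any, datamodule: Any = None, ckpt_path: Optional[str] = None) -> None:
        self.calls.append(("test", model, datamodule, ckpt_path))


trainer = RecordingTrainer()
evaluated: List[Any] = []
multi_dm_training_sketch(
    trainer, dm="train_dm", lit_mod="model", test_dm="test_dm", test_fn=evaluated.append
)
print(trainer.calls)
# [('fit', 'model', 'train_dm', None), ('test', 'model', 'test_dm', None)]
print(evaluated)
# ['model']
```

Keeping the test datamodule separate from the training one lets the same trained module be evaluated on a different region or period without rebuilding the training pipeline; `test_fn` is the hook for any metric computation beyond what `trainer.test` logs.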