utils

This module provides utility functions for 4D-VarNet.

Utility functions include data preprocessing, optimization configuration, diagnostics, and evaluation metrics.

Functions:

pipe: Apply a sequence of functions to an input.
kwgetattr: Get an attribute of an object by name.
callmap: Apply a list of functions to an input and return the results.
half_lr_adam: Configure an Adam optimizer with specific learning rates for model components.
cosanneal_lr_adam: Configure an Adam optimizer with cosine annealing learning rate scheduling.
cosanneal_lr_lion: Configure a Lion optimizer with cosine annealing learning rate scheduling.
triang_lr_adam: Configure an Adam optimizer with triangular cyclic learning rate scheduling.
remove_nan: Fill NaN values in a DataArray using Gauss-Seidel interpolation.
get_constant_crop: Generate a constant cropping mask for patches.
get_cropped_hanning_mask: Generate a cropped Hanning mask for patches.
get_triang_time_wei: Generate a triangular time weighting mask for patches.
load_enatl: Load and preprocess the ENATL dataset.
load_altimetry_data: Load and preprocess altimetry data.
load_dc_data: Load DC data (currently a placeholder function).
load_full_natl_data: Load and preprocess the full NATL dataset.
rmse_based_scores_from_ds: Compute RMSE-based scores from a dataset.
psd_based_scores_from_ds: Compute PSD-based scores from a dataset.
rmse_based_scores: Compute RMSE-based scores for reconstruction evaluation.
psd_based_scores: Compute PSD-based scores for reconstruction evaluation.
diagnostics: Compute diagnostics for a given test domain.
diagnostics_from_ds: Compute diagnostics from a dataset.
test_osse: Perform OSSE testing and compute metrics.
ensemble_metrics: Compute ensemble metrics for multiple checkpoints.
add_geo_attrs: Add geographic attributes to a DataArray.
vort: Compute vorticity from a DataArray.
geo_energy: Compute geostrophic energy from a DataArray.
best_ckpt: Retrieve the best checkpoint from an experiment directory.
load_cfg: Load configuration files for an experiment.

add_geo_attrs(da)

Add geographic attributes (longitude and latitude units) to a DataArray.

Parameters:

da (xarray.DataArray): The input DataArray. Required.

Returns:

xarray.DataArray: The DataArray with geographic attributes added.

Source code in ocean4dvarnet/utils.py
def add_geo_attrs(da):
    """
    Add geographic attributes (longitude and latitude units) to a DataArray.

    Args:
        da (xarray.DataArray): The input DataArray.

    Returns:
        xarray.DataArray: The DataArray with geographic attributes added.
    """
    da["lon"] = da.lon.assign_attrs(units="degrees_east")
    da["lat"] = da.lat.assign_attrs(units="degrees_north")
    return da
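
A minimal sketch of the effect on a synthetic field (the grid values are made up):

import numpy as np
import xarray as xr

from ocean4dvarnet.utils import add_geo_attrs

da = xr.DataArray(
    np.zeros((10, 20)),
    coords=dict(lat=np.linspace(32, 34.25, 10), lon=np.linspace(-66, -61.25, 20)),
    dims=("lat", "lon"),
)
da = add_geo_attrs(da)
assert da.lat.attrs["units"] == "degrees_north"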

best_ckpt(xp_dir)

Retrieve the best checkpoint from an experiment directory.

Parameters:

xp_dir (str): Path to the experiment directory. Required.

Returns:

str: Path to the best checkpoint file.

Source code in ocean4dvarnet/utils.py
def best_ckpt(xp_dir):
    """
    Retrieve the best checkpoint from an experiment directory.

    Args:
        xp_dir (str): Path to the experiment directory.

    Returns:
        str: Path to the best checkpoint file.
    """
    _, xpn = load_cfg(xp_dir)
    if xpn is None:
        return None
    print(Path(xp_dir) / xpn / 'checkpoints')
    ckpt_last = max(
        (Path(xp_dir) / xpn / 'checkpoints').glob("*.ckpt"), key=lambda p: p.stat().st_mtime
    )
    cbs = torch.load(ckpt_last)["callbacks"]
    ckpt_cb = cbs[next(k for k in cbs.keys() if "ModelCheckpoint" in k)]
    return ckpt_cb["best_model_path"]
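
A hedged usage sketch; the run directory below is hypothetical, but it must contain the .hydra configs and the <xp_name>/checkpoints folder that load_cfg and best_ckpt look for:

from ocean4dvarnet.utils import best_ckpt

ckpt = best_ckpt("outputs/2023-01-01/12-00-00")  # hypothetical Hydra run dir
if ckpt is not None:
    print("best checkpoint:", ckpt)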

callmap(inp, fns)

Apply a list of functions to an input and return the results.

Parameters:

inp: The input to process. Required.
fns (list): A list of functions to apply. Required.

Returns:

list: A list of results from applying each function.

Source code in ocean4dvarnet/utils.py
def callmap(inp, fns):
    """
    Apply a list of functions to an input and return the results.

    Args:
        inp: The input to process.
        fns (list): A list of functions to apply.

    Returns:
        list: A list of results from applying each function.
    """
    return [fn(inp) for fn in fns]
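
For instance, to evaluate several reductions of the same input in one pass:

import numpy as np

from ocean4dvarnet.utils import callmap

x = np.array([1.0, 2.0, 3.0])
lo, hi, mean = callmap(x, [np.min, np.max, np.mean])  # 1.0, 3.0, 2.0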

cosanneal_lr_adam(lit_mod, lr, T_max=100, weight_decay=0.0)

Configure an Adam optimizer with cosine annealing learning rate scheduling.

Parameters:

lit_mod: The Lightning module containing the model. Required.
lr (float): The base learning rate. Required.
T_max (int): Maximum number of iterations for the scheduler. Defaults to 100.
weight_decay (float): Weight decay for the optimizer. Defaults to 0.0.

Returns:

dict: A dictionary containing the optimizer and scheduler.

Source code in ocean4dvarnet/utils.py
def cosanneal_lr_adam(lit_mod, lr, T_max=100, weight_decay=0.):
    """
    Configure an Adam optimizer with cosine annealing learning rate scheduling.

    Args:
        lit_mod: The Lightning module containing the model.
        lr (float): The base learning rate.
        T_max (int): Maximum number of iterations for the scheduler.
        weight_decay (float): Weight decay for the optimizer.

    Returns:
        dict: A dictionary containing the optimizer and scheduler.
    """
    opt = torch.optim.Adam(
        [
            {"params": lit_mod.solver.grad_mod.parameters(), "lr": lr},
            {"params": lit_mod.solver.obs_cost.parameters(), "lr": lr},
            {"params": lit_mod.solver.prior_cost.parameters(), "lr": lr / 2},
        ], weight_decay=weight_decay
    )
    return {
        "optimizer": opt,
        "lr_scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=T_max
        ),
    }
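
These optimizer factories are meant to be returned from a LightningModule's configure_optimizers hook. A minimal wiring sketch, under the assumption that the module exposes a solver with grad_mod, obs_cost and prior_cost submodules (MyLitModule itself is hypothetical):

import functools

import pytorch_lightning as pl

from ocean4dvarnet.utils import cosanneal_lr_adam


class MyLitModule(pl.LightningModule):
    def __init__(self, solver):
        super().__init__()
        # cosanneal_lr_adam reads solver.grad_mod, solver.obs_cost
        # and solver.prior_cost, so the solver must expose them.
        self.solver = solver
        self.opt_fn = functools.partial(cosanneal_lr_adam, lr=1e-3, T_max=100)

    def configure_optimizers(self):
        # The factory returns {"optimizer": ..., "lr_scheduler": ...},
        # which Lightning accepts as-is.
        return self.opt_fn(self)

Note that the prior cost is trained at half the base learning rate, the same split used by half_lr_adam.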

cosanneal_lr_lion(lit_mod, lr, T_max=100)

Configure a Lion optimizer with cosine annealing learning rate scheduling.

Parameters:

lit_mod: The Lightning module containing the model. Required.
lr (float): The base learning rate. Required.
T_max (int): Maximum number of iterations for the scheduler. Defaults to 100.

Returns:

dict: A dictionary containing the optimizer and scheduler.

Source code in ocean4dvarnet/utils.py
def cosanneal_lr_lion(lit_mod, lr, T_max=100):
    """
    Configure a Lion optimizer with cosine annealing learning rate scheduling.

    Args:
        lit_mod: The Lightning module containing the model.
        lr (float): The base learning rate.
        T_max (int): Maximum number of iterations for the scheduler.

    Returns:
        dict: A dictionary containing the optimizer and scheduler.
    """
    import lion_pytorch
    opt = lion_pytorch.Lion(
        [
            {"params": lit_mod.solver.grad_mod.parameters(), "lr": lr},
            {"params": lit_mod.solver.prior_cost.parameters(), "lr": lr / 2},
        ], weight_decay=1e-3
    )
    return {
        "optimizer": opt,
        "lr_scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=T_max),
    }

diagnostics(lit_mod, test_domain)

Compute diagnostics for a given test domain.

Parameters:

lit_mod: The Lightning module containing the model. Required.
test_domain (dict): The test domain to evaluate. Required.

Returns:

pandas.Series: A series containing diagnostic metrics.

Source code in ocean4dvarnet/utils.py
def diagnostics(lit_mod, test_domain):
    """
    Compute diagnostics for a given test domain.

    Args:
        lit_mod: The Lightning module containing the model.
        test_domain (dict): The test domain to evaluate.

    Returns:
        pandas.Series: A series containing diagnostic metrics.
    """
    test_data = lit_mod.test_data.sel(test_domain)
    return diagnostics_from_ds(test_data, test_domain)

diagnostics_from_ds(test_data, test_domain)

Compute diagnostics from a dataset.

Parameters:

test_data (xarray.Dataset): The test data. Required.
test_domain (dict): The test domain to evaluate. Required.

Returns:

pandas.Series: A series containing diagnostic metrics.

Source code in ocean4dvarnet/utils.py
def diagnostics_from_ds(test_data, test_domain):
    """
    Compute diagnostics from a dataset.

    Args:
        test_data (xarray.Dataset): The test data.
        test_domain (dict): The test domain to evaluate.

    Returns:
        pandas.Series: A series containing diagnostic metrics.
    """
    test_data = test_data.sel(test_domain)
    metrics = {
        "RMSE (m)": test_data.pipe(lambda ds: (ds.out - ds.tgt))
        .pipe(lambda da: da**2)
        .mean()
        .pipe(np.sqrt)
        .item(),
        **dict(
            zip(
                ["λx", "λt"],
                test_data.pipe(lambda ds: psd_based_scores(ds.out, ds.tgt)[1:]),
            )
        ),
        **dict(
            zip(
                ["μ", "σ"],
                test_data.pipe(lambda ds: rmse_based_scores(ds.out, ds.tgt)[2:]),
            )
        ),
    }
    return pd.Series(metrics, name="osse_metrics")
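
A hedged sketch of the inputs this expects: a Dataset holding out (the reconstruction) and tgt (the reference) on (time, lat, lon), plus a selection dict for the evaluation domain. The file name and domain bounds below are illustrative:

import xarray as xr

from ocean4dvarnet.utils import diagnostics_from_ds

test_data = xr.open_dataset("osse_test_data.nc")  # hypothetical file with `out` and `tgt`
test_domain = dict(
    time=slice("2012-10-22", "2012-12-02"),  # illustrative bounds
    lat=slice(33, 43),
    lon=slice(-65, -55),
)
metrics = diagnostics_from_ds(test_data, test_domain)
print(metrics.to_markdown())  # RMSE (m), λx, λt, μ, σ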

ensemble_metrics(trainer, lit_mod, ckpt_list, dm, save_path)

Compute ensemble metrics for multiple checkpoints.

Parameters:

trainer (pl.Trainer): The PyTorch Lightning trainer instance. Required.
lit_mod (pl.LightningModule): The Lightning module to test. Required.
ckpt_list (list): List of checkpoint paths to evaluate. Required.
dm (pl.LightningDataModule): The datamodule for testing. Required.
save_path (str): Path to save the metrics and ensemble outputs. Required.

Returns:

None

Source code in ocean4dvarnet/utils.py
def ensemble_metrics(trainer, lit_mod, ckpt_list, dm, save_path):
    """
    Compute ensemble metrics for multiple checkpoints.

    Args:
        trainer (pl.Trainer): The PyTorch Lightning trainer instance.
        lit_mod (pl.LightningModule): The Lightning module to test.
        ckpt_list (list): List of checkpoint paths to evaluate.
        dm (pl.LightningDataModule): The datamodule for testing.
        save_path (str): Path to save the metrics and ensemble outputs.

    Returns:
        None
    """
    metrics = []
    test_data = xr.Dataset()
    for i, ckpt in enumerate(ckpt_list):
        trainer.test(lit_mod, ckpt_path=ckpt, datamodule=dm)
        rmse = (
            lit_mod.test_data.pipe(lambda ds: (ds.out - ds.ssh))
            .pipe(lambda da: da**2)
            .mean()
            .pipe(np.sqrt)
            .item()
        )
        lx, lt = psd_based_scores(lit_mod.test_data.out, lit_mod.test_data.ssh)[1:]
        mu, sig = rmse_based_scores(lit_mod.test_data.out, lit_mod.test_data.ssh)[2:]

        metrics.append(dict(ckpt=ckpt, rmse=rmse, lx=lx, lt=lt, mu=mu, sig=sig))

        if i == 0:
            test_data = lit_mod.test_data
            test_data = test_data.rename(out=f"out_{i}")
        else:
            test_data = test_data.assign(**{f"out_{i}": lit_mod.test_data.out})
        test_data[f"out_{i}"] = test_data[f"out_{i}"].assign_attrs(
            ckpt=str(ckpt)
        )

    metric_df = pd.DataFrame(metrics)
    print(metric_df.to_markdown())
    print(metric_df.describe().to_markdown())
    metric_df.to_csv(save_path + "/metrics.csv")
    test_data.to_netcdf(save_path + "/ens_out.nc")

geo_energy(da)

Compute the geostrophic energy from a DataArray.

Parameters:

da (xarray.DataArray): The input DataArray. Required.

Returns:

xarray.DataArray: The geostrophic energy computed from the input data.

Source code in ocean4dvarnet/utils.py
def geo_energy(da):
    """
    Compute the geostrophic energy from a DataArray.

    Args:
        da (xarray.DataArray): The input DataArray.

    Returns:
        xarray.DataArray: The geostrophic energy computed from the input data.
    """
    return np.hypot(*mpcalc.geostrophic_wind(da.pipe(add_geo_attrs))).metpy.dequantify()

get_constant_crop(patch_dims, crop, dim_order=['time', 'lat', 'lon'])

Generate a constant cropping mask for patches.

Parameters:

patch_dims (dict): Dimensions of the patch. Required.
crop (dict): Crop sizes for each dimension. Required.
dim_order (list): Order of dimensions. Defaults to ['time', 'lat', 'lon'].

Returns:

numpy.ndarray: A mask with cropped regions set to 0 and others to 1.

Source code in ocean4dvarnet/utils.py
def get_constant_crop(patch_dims, crop, dim_order=["time", "lat", "lon"]):
    """
    Generate a constant cropping mask for patches.

    Args:
        patch_dims (dict): Dimensions of the patch.
        crop (dict): Crop sizes for each dimension.
        dim_order (list): Order of dimensions.

    Returns:
        numpy.ndarray: A mask with cropped regions set to 0 and others to 1.
    """
    patch_weight = np.zeros([patch_dims[d] for d in dim_order], dtype="float32")
    mask = tuple(
        slice(crop[d], -crop[d]) if crop.get(d, 0) > 0 else slice(None, None)
        for d in dim_order
    )
    patch_weight[mask] = 1.0
    return patch_weight
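
For example, to keep a 15-frame 240x240 patch but zero out a border of 2 time steps and 20 pixels (sizes are illustrative):

from ocean4dvarnet.utils import get_constant_crop

w = get_constant_crop(
    patch_dims=dict(time=15, lat=240, lon=240),
    crop=dict(time=2, lat=20, lon=20),
)
assert w.shape == (15, 240, 240)
assert w[0].max() == 0.0      # first time step fully cropped
assert w[7, 120, 120] == 1.0  # patch interior kept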

get_cropped_hanning_mask(patch_dims, crop, **kwargs)

Generate a cropped Hanning mask for patches.

Parameters:

patch_dims (dict): Dimensions of the patch. Required.
crop (dict): Crop sizes for each dimension. Required.

Returns:

numpy.ndarray: The cropped Hanning mask.

Source code in ocean4dvarnet/utils.py
def get_cropped_hanning_mask(patch_dims, crop, **kwargs):
    """
    Generate a cropped Hanning mask for patches.

    Args:
        patch_dims (dict): Dimensions of the patch.
        crop (dict): Crop sizes for each dimension.

    Returns:
        numpy.ndarray: The cropped Hanning mask.
    """
    pw = get_constant_crop(patch_dims, crop)
    t_msk = kornia.filters.get_hanning_kernel1d(patch_dims["time"])
    patch_weight = t_msk[:, None, None] * pw
    return patch_weight.cpu().numpy()

get_triang_time_wei(patch_dims, offset=0, **crop_kw)

Generate a triangular time weighting mask for patches.

Parameters:

patch_dims (dict): Dimensions of the patch. Required.
offset (int): Offset for the triangular weighting. Defaults to 0.
crop_kw (dict): Additional cropping parameters. Defaults to {}.

Returns:

numpy.ndarray: The triangular time weighting mask.

Source code in ocean4dvarnet/utils.py
def get_triang_time_wei(patch_dims, offset=0, **crop_kw):
    """
    Generate a triangular time weighting mask for patches.

    Args:
        patch_dims (dict): Dimensions of the patch.
        offset (int): Offset for the triangular weighting.
        crop_kw (dict): Additional cropping parameters.

    Returns:
        numpy.ndarray: The triangular time weighting mask.
    """
    pw = get_constant_crop(patch_dims, **crop_kw)
    return np.fromfunction(
        lambda t, *a: (
            (1 - np.abs(offset + 2 * t - patch_dims["time"]) / patch_dims["time"]) * pw
        ),
        patch_dims.values(),
    )
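
The weight rises linearly from the window edges to a peak mid-window, multiplied by the constant crop mask built from crop_kw. A small sketch (sizes are illustrative):

from ocean4dvarnet.utils import get_triang_time_wei

w = get_triang_time_wei(
    dict(time=15, lat=240, lon=240),
    crop=dict(time=0, lat=20, lon=20),
)
print(w[:, 120, 120])  # triangular profile along the time axis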

half_lr_adam(lit_mod, lr)

Configure an Adam optimizer with specific learning rates for model components.

Parameters:

lit_mod: The Lightning module containing the model. Required.
lr (float): The base learning rate. Required.

Returns:

torch.optim.Adam: The configured optimizer.

Source code in ocean4dvarnet/utils.py
def half_lr_adam(lit_mod, lr):
    """
    Configure an Adam optimizer with specific learning rates for model components.

    Args:
        lit_mod: The Lightning module containing the model.
        lr (float): The base learning rate.

    Returns:
        torch.optim.Adam: The configured optimizer.
    """
    return torch.optim.Adam(
        [
            {"params": lit_mod.solver.grad_mod.parameters(), "lr": lr},
            {"params": lit_mod.solver.obs_cost.parameters(), "lr": lr},
            {"params": lit_mod.solver.prior_cost.parameters(), "lr": lr / 2},
        ],
    )

kwgetattr(obj, name)

Get an attribute of an object by name.

Parameters:

obj: The object to query. Required.
name (str): The name of the attribute. Required.

Returns:

The value of the attribute.

Source code in ocean4dvarnet/utils.py
def kwgetattr(obj, name):
    """
    Get an attribute of an object by name.

    Args:
        obj: The object to query.
        name (str): The name of the attribute.

    Returns:
        The value of the attribute.
    """
    return getattr(obj, name)

load_altimetry_data(path, obs_from_tgt=False)

Load and preprocess altimetry data.

Parameters:

path (str): Path to the altimetry dataset. Required.
obs_from_tgt (bool): Whether to use target data as observations. Defaults to False.

Returns:

xarray.DataArray: The preprocessed altimetry dataset.

Source code in ocean4dvarnet/utils.py
def load_altimetry_data(path, obs_from_tgt=False):
    """
    Load and preprocess altimetry data.

    Args:
        path (str): Path to the altimetry dataset.
        obs_from_tgt (bool): Whether to use target data as observations.

    Returns:
        xarray.DataArray: The preprocessed altimetry dataset.
    """
    ds = (
        xr.open_dataset(path)
        # .assign(ssh=lambda ds: ds.ssh.coarsen(lon=2, lat=2).mean().interp(lat=ds.lat, lon=ds.lon))
        .load()
        .assign(
            input=lambda ds: ds.nadir_obs,
            tgt=lambda ds: remove_nan(ds.ssh),
        )
    )

    if obs_from_tgt:
        ds = ds.assign(input=ds.tgt.where(np.isfinite(ds.input), np.nan))

    return (
        ds[[*data.TrainingItem._fields]]
        .transpose("time", "lat", "lon")
        .to_array()
    )

load_cfg(xp_dir)

Load configuration files for an experiment.

Parameters:

xp_dir (str): Path to the experiment directory. Required.

Returns:

tuple: A tuple containing the configuration and the experiment name.

Source code in ocean4dvarnet/utils.py
def load_cfg(xp_dir):
    """
    Load configuration files for an experiment.

    Args:
        xp_dir (str): Path to the experiment directory.

    Returns:
        tuple: A tuple containing the configuration and the experiment name.
    """
    hydra_cfg = OmegaConf.load(Path(xp_dir) / ".hydra/hydra.yaml").hydra
    cfg = OmegaConf.load(Path(xp_dir) / ".hydra/config.yaml")
    OmegaConf.register_new_resolver(
        "hydra", lambda k: OmegaConf.select(hydra_cfg, k), replace=True
    )
    try:
        OmegaConf.resolve(cfg)
        OmegaConf.resolve(cfg)
    except Exception:
        return None, None

    return cfg, OmegaConf.select(hydra_cfg, "runtime.choices.xp")

load_dc_data(**kwargs)

Load DC data.

This is currently a placeholder function for loading DC data.

Returns:

None

Source code in ocean4dvarnet/utils.py
def load_dc_data(**kwargs):
    """
    Load DC data.

    This is currently a placeholder function for loading DC data.

    Args:
        **kwargs: Unused placeholder arguments.

    Returns:
        None
    """
    path_gt = "../sla-data-registry/NATL60/NATL/ref_new/NATL60-CJM165_NATL_ssh_y2013.1y.nc"
    path_obs = "NATL60/NATL/data_new/dataset_nadir_0d.nc"

load_enatl(*args, obs_from_tgt=True, **kwargs)

Load and preprocess the ENATL dataset.

Parameters:

obs_from_tgt (bool): Whether to use target data as observations. Defaults to True.

Returns:

xarray.DataArray: The preprocessed ENATL dataset.

Source code in ocean4dvarnet/utils.py
def load_enatl(*args, obs_from_tgt=True, **kwargs):
    """
    Load and preprocess the ENATL dataset.

    Args:
        obs_from_tgt (bool): Whether to use target data as observations.

    Returns:
        xarray.DataArray: The preprocessed ENATL dataset.
    """
    # ds = xr.open_dataset('../sla-data-registry/qdata/enatl_wo_tide.nc')
    # print(ds)
    # return ds.rename(nadir_obs='input', ssh='tgt')\
    #     .to_array()\
    #     .transpose('variable', 'time', 'lat', 'lon')\
    #     .sortby('variable')
    ssh = xr.open_zarr('../sla-data-registry/enatl_preproc/truth_SLA_SSH_NATL60.zarr/').ssh
    nadirs = xr.open_zarr('../sla-data-registry/enatl_preproc/SLA_SSH_5nadirs.zarr/').ssh
    ssh = ssh.interp(
        lon=np.arange(ssh.lon.min(), ssh.lon.max(), 1/20),
        lat=np.arange(ssh.lat.min(), ssh.lat.max(), 1/20)
    )
    nadirs = nadirs.interp(time=ssh.time, method='nearest')\
        .interp(lat=ssh.lat, lon=ssh.lon, method='zero')
    ds = xr.Dataset(dict(input=nadirs, tgt=(ssh.dims, ssh.values)), nadirs.coords)
    if obs_from_tgt:
        ds = ds.assign(input=ds.tgt.transpose(*ds.input.dims).where(np.isfinite(ds.input), np.nan))
    return ds.transpose('time', 'lat', 'lon').to_array().load().sortby('variable')

load_full_natl_data(path_obs='../sla-data-registry/CalData/cal_data_new_errs.nc', path_gt='../sla-data-registry/NATL60/NATL/ref_new/NATL60-CJM165_NATL_ssh_y2013.1y.nc', obs_var='five_nadirs', gt_var='ssh', **kwargs)

Load and preprocess the full NATL dataset.

Parameters:

path_obs (str): Path to the observation dataset. Defaults to '../sla-data-registry/CalData/cal_data_new_errs.nc'.
path_gt (str): Path to the ground truth dataset. Defaults to '../sla-data-registry/NATL60/NATL/ref_new/NATL60-CJM165_NATL_ssh_y2013.1y.nc'.
obs_var (str): Observation variable name. Defaults to 'five_nadirs'.
gt_var (str): Ground truth variable name. Defaults to 'ssh'.

Returns:

xarray.DataArray: The preprocessed NATL dataset.

Source code in ocean4dvarnet/utils.py
def load_full_natl_data(
    path_obs="../sla-data-registry/CalData/cal_data_new_errs.nc",
    path_gt="../sla-data-registry/NATL60/NATL/ref_new/NATL60-CJM165_NATL_ssh_y2013.1y.nc",
    obs_var='five_nadirs',
    gt_var='ssh',
    **kwargs
):
    """
    Load and preprocess the full NATL dataset.

    Args:
        path_obs (str): Path to the observation dataset.
        path_gt (str): Path to the ground truth dataset.
        obs_var (str): Observation variable name.
        gt_var (str): Ground truth variable name.

    Returns:
        xarray.DataArray: The preprocessed NATL dataset.
    """
    inp = xr.open_dataset(path_obs)[obs_var]
    gt = (
        xr.open_dataset(path_gt)[gt_var]
        # .isel(time=slice(0, -1))
        .sel(lat=inp.lat, lon=inp.lon, method="nearest")
    )

    return xr.Dataset(dict(input=inp, tgt=(gt.dims, gt.values)), inp.coords).to_array().sortby('variable')

pipe(inp, fns)

Apply a sequence of functions to an input.

Parameters:

inp: The input to process. Required.
fns (list): A list of functions to apply. Required.

Returns:

The processed input after applying all functions.

Source code in ocean4dvarnet/utils.py
def pipe(inp, fns):
    """
    Apply a sequence of functions to an input.

    Args:
        inp: The input to process.
        fns (list): A list of functions to apply.

    Returns:
        The processed input after applying all functions.
    """
    for f in fns:
        inp = f(inp)
    return inp
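
For example, a root-mean-square written as a pipeline:

import numpy as np

from ocean4dvarnet.utils import pipe

x = np.array([3.0, 4.0])
rms = pipe(x, [np.square, np.mean, np.sqrt])  # ~3.536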

psd_based_scores(da_rec, da_ref)

Compute PSD-based scores for reconstruction evaluation.

Parameters:

da_rec (xarray.DataArray): The reconstructed data. Required.
da_ref (xarray.DataArray): The reference data. Required.

Returns:

tuple: A tuple containing PSD-based scores and resolved wavelengths.

Source code in ocean4dvarnet/utils.py
def psd_based_scores(da_rec, da_ref):
    """
    Compute PSD-based scores for reconstruction evaluation.

    Args:
        da_rec (xarray.DataArray): The reconstructed data.
        da_ref (xarray.DataArray): The reference data.

    Returns:
        tuple: A tuple containing PSD-based scores and resolved wavelengths.
    """
    err = da_rec - da_ref
    err["time"] = (err.time - err.time[0]) / np.timedelta64(1, "D")
    signal = da_ref
    signal["time"] = (signal.time - signal.time[0]) / np.timedelta64(1, "D")
    psd_err = xrft.power_spectrum(
        err, dim=["time", "lon"], detrend="constant", window="hann"
    ).compute()
    psd_signal = xrft.power_spectrum(
        signal, dim=["time", "lon"], detrend="constant", window="hann"
    ).compute()
    mean_psd_signal = psd_signal.mean(dim="lat").where(
        (psd_signal.freq_lon > 0.0) & (psd_signal.freq_time > 0), drop=True
    )
    mean_psd_err = psd_err.mean(dim="lat").where(
        (psd_err.freq_lon > 0.0) & (psd_err.freq_time > 0), drop=True
    )
    psd_based_score = 1.0 - mean_psd_err / mean_psd_signal
    level = [0.5]
    cs = plt.contour(
        1.0 / psd_based_score.freq_lon.values,
        1.0 / psd_based_score.freq_time.values,
        psd_based_score,
        level,
    )
    x05, y05 = cs.collections[0].get_paths()[0].vertices.T
    plt.close()

    shortest_spatial_wavelength_resolved = np.min(x05)
    shortest_temporal_wavelength_resolved = np.min(y05)
    psd_da = 1.0 - mean_psd_err / mean_psd_signal
    psd_da.name = "psd_score"
    return (
        psd_da.to_dataset(),
        np.round(shortest_spatial_wavelength_resolved, 3).item(),
        np.round(shortest_temporal_wavelength_resolved, 3).item(),
    )

psd_based_scores_from_ds(ds, ref_variable='tgt', study_variable='out')

Compute PSD-based scores from a dataset.

Parameters:

ds (xarray.Dataset): The dataset containing the reference and study variables. Required.
ref_variable (str): The name of the reference variable. Defaults to 'tgt'.
study_variable (str): The name of the study variable. Defaults to 'out'.

Returns:

list: A list containing PSD-based scores.

Source code in ocean4dvarnet/utils.py
def psd_based_scores_from_ds(ds, ref_variable='tgt', study_variable='out'):
    """
    Compute PSD-based scores from a dataset.

    Args:
        ds (xarray.Dataset): The dataset containing the reference and study variables.
        ref_variable (str): The name of the reference variable.
        study_variable (str): The name of the study variable.

    Returns:
        list: A list containing PSD-based scores.
    """
    try:
        return psd_based_scores(ds[study_variable], ds[ref_variable])[1:]
    except Exception:
        return [np.nan, np.nan]

remove_nan(da)

Fill NaN values in a DataArray using Gauss-Seidel interpolation.

Parameters:

da (xarray.DataArray): The input DataArray. Required.

Returns:

xarray.DataArray: The DataArray with NaN values filled.

Source code in ocean4dvarnet/utils.py
def remove_nan(da):
    """
    Fill NaN values in a DataArray using Gauss-Seidel interpolation.

    Args:
        da (xarray.DataArray): The input DataArray.

    Returns:
        xarray.DataArray: The DataArray with NaN values filled.
    """
    da["lon"] = da.lon.assign_attrs(units="degrees_east")
    da["lat"] = da.lat.assign_attrs(units="degrees_north")

    da.transpose("lon", "lat", "time")[:, :] = pyinterp.fill.gauss_seidel(
        pyinterp.backends.xarray.Grid3D(da)
    )[1]
    return da
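
A hedged usage sketch; the file is hypothetical, the DataArray must carry lon, lat and time dimensions, and pyinterp must be installed since the fill runs through its Gauss-Seidel solver. Note the input is modified in place and also returned:

import xarray as xr

from ocean4dvarnet.utils import remove_nan

ssh = xr.open_dataset("natl_ssh.nc").ssh  # hypothetical gappy SSH field
ssh_filled = remove_nan(ssh)
assert not ssh_filled.isnull().any()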

rmse_based_scores(da_rec, da_ref)

Compute RMSE-based scores for reconstruction evaluation.

Parameters:

da_rec (xarray.DataArray): The reconstructed data. Required.
da_ref (xarray.DataArray): The reference data. Required.

Returns:

tuple: A tuple containing RMSE-based scores.

Source code in ocean4dvarnet/utils.py
def rmse_based_scores(da_rec, da_ref):
    """
    Compute RMSE-based scores for reconstruction evaluation.

    Args:
        da_rec (xarray.DataArray): The reconstructed data.
        da_ref (xarray.DataArray): The reference data.

    Returns:
        tuple: A tuple containing RMSE-based scores.
    """
    rmse_t = (
        1.0
        - (((da_rec - da_ref) ** 2).mean(dim=("lon", "lat"))) ** 0.5
        / (((da_ref) ** 2).mean(dim=("lon", "lat"))) ** 0.5
    )
    rmse_xy = (((da_rec - da_ref) ** 2).mean(dim=("time"))) ** 0.5
    rmse_t = rmse_t.rename("rmse_t")
    rmse_xy = rmse_xy.rename("rmse_xy")
    reconstruction_error_stability_metric = rmse_t.std().values
    leaderboard_rmse = (
        1.0 - (((da_rec - da_ref) ** 2).mean()) ** 0.5 / (((da_ref) ** 2).mean()) ** 0.5
    )
    return (
        rmse_t,
        rmse_xy,
        np.round(leaderboard_rmse.values, 5).item(),
        np.round(reconstruction_error_stability_metric, 5).item(),
    )
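
A self-contained sketch on synthetic data (grid and noise level are made up), showing the four returned values:

import numpy as np
import pandas as pd
import xarray as xr

from ocean4dvarnet.utils import rmse_based_scores

rng = np.random.default_rng(0)
coords = dict(
    time=pd.date_range("2013-01-01", periods=10),
    lat=np.linspace(33, 43, 40),
    lon=np.linspace(-65, -55, 40),
)
ref = xr.DataArray(rng.normal(size=(10, 40, 40)), coords=coords, dims=("time", "lat", "lon"))
rec = ref + 0.1 * rng.normal(size=ref.shape)

rmse_t, rmse_xy, mu, sig = rmse_based_scores(rec, ref)
print(mu, sig)  # normalized RMSE score and its stability across time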

rmse_based_scores_from_ds(ds, ref_variable='tgt', study_variable='out')

Compute RMSE-based scores from a dataset.

Parameters:

ds (xarray.Dataset): The dataset containing the reference and study variables. Required.
ref_variable (str): The name of the reference variable. Defaults to 'tgt'.
study_variable (str): The name of the study variable. Defaults to 'out'.

Returns:

list: A list containing RMSE-based scores.

Source code in ocean4dvarnet/utils.py
def rmse_based_scores_from_ds(ds, ref_variable='tgt', study_variable='out'):
    """
    Compute RMSE-based scores from a dataset.

    Args:
        ds (xarray.Dataset): The dataset containing the reference and study variables.
        ref_variable (str): The name of the reference variable.
        study_variable (str): The name of the study variable.

    Returns:
        list: A list containing RMSE-based scores.
    """
    try:
        return rmse_based_scores(ds[study_variable], ds[ref_variable])[2:]
    except Exception:
        return [np.nan, np.nan]

test_osse(trainer, lit_mod, osse_dm, osse_test_domain, ckpt, diag_data_dir=None)

Perform OSSE (Observing System Simulation Experiment) testing and compute metrics.

Parameters:

trainer (pl.Trainer): The PyTorch Lightning trainer instance. Required.
lit_mod (pl.LightningModule): The Lightning module to test. Required.
osse_dm (pl.LightningDataModule): The datamodule for OSSE testing. Required.
osse_test_domain (dict): The test domain for evaluation. Required.
ckpt (str): Path to the checkpoint to load. Required.
diag_data_dir (Path): Directory to save diagnostic data. Defaults to None.

Returns:

pandas.Series: A series containing OSSE metrics.

Source code in ocean4dvarnet/utils.py
def test_osse(trainer, lit_mod, osse_dm, osse_test_domain, ckpt, diag_data_dir=None):
    """
    Perform OSSE (Observing System Simulation Experiment) testing and compute metrics.

    Args:
        trainer (pl.Trainer): The PyTorch Lightning trainer instance.
        lit_mod (pl.LightningModule): The Lightning module to test.
        osse_dm (pl.LightningDataModule): The datamodule for OSSE testing.
        osse_test_domain (dict): The test domain for evaluation.
        ckpt (str): Path to the checkpoint to load.
        diag_data_dir (Path, optional): Directory to save diagnostic data.

    Returns:
        pandas.Series: A series containing OSSE metrics.
    """
    lit_mod.norm_stats = osse_dm.norm_stats()
    trainer.test(lit_mod, datamodule=osse_dm, ckpt_path=ckpt)
    osse_tdat = lit_mod.test_data[['out', 'ssh']]
    osse_metrics = diagnostics_from_ds(
        osse_tdat, test_domain=osse_test_domain
    )

    print(osse_metrics.to_markdown())

    if diag_data_dir is not None:
        osse_metrics.to_csv(diag_data_dir / "osse_metrics.csv")
        if (diag_data_dir / "osse_test_data.nc").exists():
            xr.open_dataset(diag_data_dir / "osse_test_data.nc").close()
        osse_tdat.to_netcdf(diag_data_dir / "osse_test_data.nc")

    return osse_metrics

triang_lr_adam(lit_mod, lr_min=5e-05, lr_max=0.003, nsteps=200)

Configure an Adam optimizer with triangular cyclic learning rate scheduling.

Parameters:

lit_mod: The Lightning module containing the model. Required.
lr_min (float): Minimum learning rate. Defaults to 5e-05.
lr_max (float): Maximum learning rate. Defaults to 0.003.
nsteps (int): Number of steps for the triangular cycle. Defaults to 200.

Returns:

dict: A dictionary containing the optimizer and scheduler.

Source code in ocean4dvarnet/utils.py
def triang_lr_adam(lit_mod, lr_min=5e-5, lr_max=3e-3, nsteps=200):
    """
    Configure an Adam optimizer with triangular cyclic learning rate scheduling.

    Args:
        lit_mod: The Lightning module containing the model.
        lr_min (float): Minimum learning rate.
        lr_max (float): Maximum learning rate.
        nsteps (int): Number of steps for the triangular cycle.

    Returns:
        dict: A dictionary containing the optimizer and scheduler.
    """
    opt = torch.optim.Adam(
        [
            {"params": lit_mod.solver.grad_mod.parameters(), "lr": lr_max},
            {"params": lit_mod.solver.prior_cost.parameters(), "lr": lr_max / 2},
        ],
    )
    return {
        "optimizer": opt,
        "lr_scheduler": torch.optim.lr_scheduler.CyclicLR(
            opt,
            base_lr=lr_min,
            max_lr=lr_max,
            step_size_up=nsteps,
            step_size_down=nsteps,
            gamma=0.95,
            cycle_momentum=False,
            mode="exp_range",
        ),
    }

vort(da)

Compute the vorticity from a DataArray.

Parameters:

da (xarray.DataArray): The input DataArray. Required.

Returns:

xarray.DataArray: The vorticity computed from the input data.

Source code in ocean4dvarnet/utils.py
def vort(da):
    """
    Compute the vorticity from a DataArray.

    Args:
        da (xarray.DataArray): The input DataArray.

    Returns:
        xarray.DataArray: The vorticity computed from the input data.
    """
    return mpcalc.vorticity(
        *mpcalc.geostrophic_wind(
            da.pipe(add_geo_attrs).assign_attrs(units="m").metpy.quantify()
        )
    ).metpy.dequantify()