
data

This module provides data handling utilities for 4D-VarNet models.

It includes classes and functions for creating datasets, augmenting data, managing data loading pipelines, and reconstructing data from patches. These utilities are designed to work seamlessly with PyTorch and xarray, enabling efficient data preprocessing and loading for machine learning tasks.

Classes:

- XrDataset: A PyTorch Dataset for extracting patches from xarray.DataArray objects.
- XrConcatDataset: A concatenation of multiple XrDatasets.
- AugmentedDataset: A dataset wrapper for applying data augmentation.
- BaseDataModule: A PyTorch Lightning DataModule for managing datasets and data loaders.
- ConcatDataModule: A DataModule for combining datasets from multiple domains.
- RandValDataModule: A DataModule for random splitting of training data into training and validation sets.

Raises:

- IncompleteScanConfiguration: Raised when the scan configuration does not cover the entire domain.
- DangerousDimOrdering: Raised when the dimension ordering of the input data is incorrect.

Key Features
  • Patch extraction: Efficiently extract patches from large xarray.DataArray objects for training.
  • Data augmentation: Support for augmenting datasets with noise and transformations.
  • Reconstruction: Reconstruct the original data from extracted patches.
  • Seamless integration: Designed to work with PyTorch Lightning for streamlined training pipelines.
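
Putting these pieces together, here is a minimal end-to-end sketch, assuming (as the rest of this module does) a leading 'variable' dimension holding the 'input' and 'tgt' fields, with the patch dims placed last; all dimension names, sizes and coordinate values below are illustrative.

import numpy as np
import xarray as xr

from ocean4dvarnet.data import BaseDataModule

# Hypothetical input: 'input' (observations) and 'tgt' (target) stacked along 'variable'.
da = xr.DataArray(
    np.random.rand(2, 60, 64, 64).astype('float32'),
    dims=('variable', 'time', 'lat', 'lon'),
    coords={'variable': ['input', 'tgt'], 'time': np.arange(60),
            'lat': np.arange(64), 'lon': np.arange(64)},
)

dm = BaseDataModule(
    input_da=da,
    domains={'train': {'time': slice(0, 39)},
             'val': {'time': slice(40, 49)},
             'test': {'time': slice(50, 59)}},
    xrds_kw=dict(patch_dims={'time': 5, 'lat': 32, 'lon': 32},
                 strides={'time': 1, 'lat': 16, 'lon': 16}),
    dl_kw=dict(batch_size=4),
)
dm.setup()
batch = next(iter(dm.train_dataloader()))  # TrainingItem of tensors, each of shape (4, 5, 32, 32)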

AugmentedDataset

Bases: Dataset

A dataset that applies data augmentation to an input dataset.

Attributes:

- inp_ds (Dataset): The input dataset.
- aug_factor (int): The number of augmented copies to generate.
- aug_only (bool): Whether to include only augmented data.
- noise_sigma (float): Standard deviation of noise to add to augmented data.

Source code in ocean4dvarnet/data.py
class AugmentedDataset(torch.utils.data.Dataset):
    """
    A dataset that applies data augmentation to an input dataset.

    Attributes:
        inp_ds (torch.utils.data.Dataset): The input dataset.
        aug_factor (int): The number of augmented copies to generate.
        aug_only (bool): Whether to include only augmented data.
        noise_sigma (float): Standard deviation of noise to add to augmented data.
    """

    def __init__(self, inp_ds, aug_factor, aug_only=False, noise_sigma=None):
        """
        Initialize the AugmentedDataset.

        Args:
            inp_ds (torch.utils.data.Dataset): The input dataset.
            aug_factor (int): The number of augmented copies to generate.
            aug_only (bool, optional): Whether to include only augmented data.
            noise_sigma (float, optional): Standard deviation of noise to add to augmented data.
        """
        self.aug_factor = aug_factor
        self.aug_only = aug_only
        self.inp_ds = inp_ds
        self.perm = np.random.permutation(len(self.inp_ds))
        self.noise_sigma = noise_sigma

    def __len__(self):
        """
        Return the total number of items in the dataset.

        Returns:
            int: Total number of items.
        """
        return len(self.inp_ds) * (1 + self.aug_factor - int(self.aug_only))

    def __getitem__(self, idx):
        """
        Get an item from the dataset.

        Args:
            idx (int): Index of the item.

        Returns:
            TrainingItem: The requested item.
        """
        if self.aug_only:
            idx = idx + len(self.inp_ds)

        if idx < len(self.inp_ds):
            return self.inp_ds[idx]

        tgt_idx = idx % len(self.inp_ds)
        perm_idx = tgt_idx
        for _ in range(idx // len(self.inp_ds)):
            perm_idx = self.perm[perm_idx]

        item = self.inp_ds[tgt_idx]
        perm_item = self.inp_ds[perm_idx]

        noise = np.zeros_like(item.input, dtype=np.float32)
        if self.noise_sigma is not None:
            noise = np.random.randn(*item.input.shape).astype(np.float32) * self.noise_sigma

        return item._replace(input=noise + np.where(np.isfinite(perm_item.input),
                             item.tgt, np.full_like(item.tgt, np.nan)))
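
A small usage sketch follows. The toy items are hypothetical, and TrainingItem is assumed to be the two-field (input, tgt) namedtuple defined in this module; any indexable collection of such items can be wrapped.

import numpy as np

from ocean4dvarnet.data import AugmentedDataset, TrainingItem  # TrainingItem: assumed (input, tgt) namedtuple

def toy_item():
    tgt = np.random.rand(5, 32, 32).astype('float32')
    obs = np.where(np.random.rand(5, 32, 32) > 0.5, tgt, np.nan).astype('float32')  # partly observed
    return TrainingItem(input=obs, tgt=tgt)

base = [toy_item() for _ in range(100)]

aug = AugmentedDataset(base, aug_factor=2, noise_sigma=0.01)
len(aug)         # 100 * (1 + 2 - 0) = 300: the originals followed by two augmented passes
item = aug[150]  # augmented: the target of one sample seen through another sample's observation mask

aug_only = AugmentedDataset(base, aug_factor=2, aug_only=True)
len(aug_only)    # 100 * (1 + 2 - 1) = 200: augmented items only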

__getitem__(idx)

Get an item from the dataset.

Parameters:

- idx (int, required): Index of the item.

Returns:

- TrainingItem: The requested item.

Source code in ocean4dvarnet/data.py
def __getitem__(self, idx):
    """
    Get an item from the dataset.

    Args:
        idx (int): Index of the item.

    Returns:
        TrainingItem: The requested item.
    """
    if self.aug_only:
        idx = idx + len(self.inp_ds)

    if idx < len(self.inp_ds):
        return self.inp_ds[idx]

    tgt_idx = idx % len(self.inp_ds)
    perm_idx = tgt_idx
    for _ in range(idx // len(self.inp_ds)):
        perm_idx = self.perm[perm_idx]

    item = self.inp_ds[tgt_idx]
    perm_item = self.inp_ds[perm_idx]

    noise = np.zeros_like(item.input, dtype=np.float32)
    if self.noise_sigma is not None:
        noise = np.random.randn(*item.input.shape).astype(np.float32) * self.noise_sigma

    return item._replace(input=noise + np.where(np.isfinite(perm_item.input),
                         item.tgt, np.full_like(item.tgt, np.nan)))

__init__(inp_ds, aug_factor, aug_only=False, noise_sigma=None)

Initialize the AugmentedDataset.

Parameters:

- inp_ds (Dataset, required): The input dataset.
- aug_factor (int, required): The number of augmented copies to generate.
- aug_only (bool, default False): Whether to include only augmented data.
- noise_sigma (float, default None): Standard deviation of noise to add to augmented data.

Source code in ocean4dvarnet/data.py
def __init__(self, inp_ds, aug_factor, aug_only=False, noise_sigma=None):
    """
    Initialize the AugmentedDataset.

    Args:
        inp_ds (torch.utils.data.Dataset): The input dataset.
        aug_factor (int): The number of augmented copies to generate.
        aug_only (bool, optional): Whether to include only augmented data.
        noise_sigma (float, optional): Standard deviation of noise to add to augmented data.
    """
    self.aug_factor = aug_factor
    self.aug_only = aug_only
    self.inp_ds = inp_ds
    self.perm = np.random.permutation(len(self.inp_ds))
    self.noise_sigma = noise_sigma

__len__()

Return the total number of items in the dataset.

Returns:

- int: Total number of items.

Source code in ocean4dvarnet/data.py
def __len__(self):
    """
    Return the total number of items in the dataset.

    Returns:
        int: Total number of items.
    """
    return len(self.inp_ds) * (1 + self.aug_factor - int(self.aug_only))

BaseDataModule

Bases: LightningDataModule

A base data module for managing datasets and data loaders in PyTorch Lightning.

Attributes:

- input_da (DataArray): The input data array.
- domains (dict): Dictionary of domain splits (train, val, test).
- xrds_kw (dict): Keyword arguments for XrDataset.
- dl_kw (dict): Keyword arguments for DataLoader.
- aug_kw (dict): Keyword arguments for AugmentedDataset.
- norm_stats (tuple): Normalization statistics (mean, std).

Source code in ocean4dvarnet/data.py
class BaseDataModule(pl.LightningDataModule):
    """
    A base data module for managing datasets and data loaders in PyTorch Lightning.

    Attributes:
        input_da (xarray.DataArray): The input data array.
        domains (dict): Dictionary of domain splits (train, val, test).
        xrds_kw (dict): Keyword arguments for XrDataset.
        dl_kw (dict): Keyword arguments for DataLoader.
        aug_kw (dict): Keyword arguments for AugmentedDataset.
        norm_stats (tuple): Normalization statistics (mean, std).
    """

    def __init__(self, input_da, domains, xrds_kw, dl_kw, aug_kw=None, norm_stats=None, **kwargs):
        """
        Initialize the BaseDataModule.

        Args:
            input_da (xarray.DataArray): The input data array.
            domains (dict): Dictionary of domain splits (train, val, test).
            xrds_kw (dict): Keyword arguments for XrDataset.
            dl_kw (dict): Keyword arguments for DataLoader.
            aug_kw (dict, optional): Keyword arguments for AugmentedDataset.
            norm_stats (tuple, optional): Normalization statistics (mean, std).
        """
        super().__init__()
        self.input_da = input_da
        self.domains = domains
        self.xrds_kw = xrds_kw
        self.dl_kw = dl_kw
        self.aug_kw = aug_kw if aug_kw is not None else {}
        self._norm_stats = norm_stats

        self.train_ds = None
        self.val_ds = None
        self.test_ds = None
        self._post_fn = None

    def norm_stats(self):
        """
        Compute or retrieve normalization statistics (mean, std).

        Returns:
            tuple: Normalization statistics (mean, std).
        """
        if self._norm_stats is None:
            self._norm_stats = self.train_mean_std()
            print("Norm stats", self._norm_stats)
        return self._norm_stats

    def train_mean_std(self, variable='tgt'):
        """
        Compute the mean and standard deviation of the training data.

        Args:
            variable (str, optional): Variable to compute statistics for.

        Returns:
            tuple: Mean and standard deviation.
        """
        train_data = self.input_da.sel(self.xrds_kw.get('domain_limits', {})).sel(self.domains['train'])
        return train_data.sel(variable=variable).pipe(lambda da: (da.mean().values.item(), da.std().values.item()))

    def post_fn(self):
        """
        Create a post-processing function for normalizing data.

        Returns:
            callable: Post-processing function.
        """
        m, s = self.norm_stats()
        def normalize(item): return (item - m) / s
        return ft.partial(ft.reduce, lambda i, f: f(i), [
            TrainingItem._make,
            lambda item: item._replace(tgt=normalize(item.tgt)),
            lambda item: item._replace(input=normalize(item.input)),
        ])

    def setup(self, stage='test'):
        """
        Set up the datasets for training, validation, and testing.

        Args:
            stage (str, optional): Stage of the setup ('train', 'val', 'test').
        """
        train_data = self.input_da.sel(self.domains['train'])
        post_fn = self.post_fn()
        self.train_ds = XrDataset(
            train_data, **self.xrds_kw, postpro_fn=post_fn,
        )
        if self.aug_kw:
            self.train_ds = AugmentedDataset(self.train_ds, **self.aug_kw)

        self.val_ds = XrDataset(
            self.input_da.sel(self.domains['val']), **self.xrds_kw, postpro_fn=post_fn,
        )
        self.test_ds = XrDataset(
            self.input_da.sel(self.domains['test']), **self.xrds_kw, postpro_fn=post_fn,
        )

    def train_dataloader(self):
        """
        Create a DataLoader for the training dataset.

        Returns:
            DataLoader: Training DataLoader.
        """
        return torch.utils.data.DataLoader(self.train_ds, shuffle=True, **self.dl_kw)

    def val_dataloader(self):
        """
        Create a DataLoader for the validation dataset.

        Returns:
            DataLoader: Validation DataLoader.
        """
        return torch.utils.data.DataLoader(self.val_ds, shuffle=False, **self.dl_kw)

    def test_dataloader(self):
        """
        Create a DataLoader for the testing dataset.

        Returns:
            DataLoader: Testing DataLoader.
        """
        return torch.utils.data.DataLoader(self.test_ds, shuffle=False, **self.dl_kw)

__init__(input_da, domains, xrds_kw, dl_kw, aug_kw=None, norm_stats=None, **kwargs)

Initialize the BaseDataModule.

Parameters:

- input_da (DataArray, required): The input data array.
- domains (dict, required): Dictionary of domain splits (train, val, test).
- xrds_kw (dict, required): Keyword arguments for XrDataset.
- dl_kw (dict, required): Keyword arguments for DataLoader.
- aug_kw (dict, default None): Keyword arguments for AugmentedDataset.
- norm_stats (tuple, default None): Normalization statistics (mean, std).

Source code in ocean4dvarnet/data.py
def __init__(self, input_da, domains, xrds_kw, dl_kw, aug_kw=None, norm_stats=None, **kwargs):
    """
    Initialize the BaseDataModule.

    Args:
        input_da (xarray.DataArray): The input data array.
        domains (dict): Dictionary of domain splits (train, val, test).
        xrds_kw (dict): Keyword arguments for XrDataset.
        dl_kw (dict): Keyword arguments for DataLoader.
        aug_kw (dict, optional): Keyword arguments for AugmentedDataset.
        norm_stats (tuple, optional): Normalization statistics (mean, std).
    """
    super().__init__()
    self.input_da = input_da
    self.domains = domains
    self.xrds_kw = xrds_kw
    self.dl_kw = dl_kw
    self.aug_kw = aug_kw if aug_kw is not None else {}
    self._norm_stats = norm_stats

    self.train_ds = None
    self.val_ds = None
    self.test_ds = None
    self._post_fn = None

norm_stats()

Compute or retrieve normalization statistics (mean, std).

Returns:

- tuple: Normalization statistics (mean, std).

Source code in ocean4dvarnet/data.py
def norm_stats(self):
    """
    Compute or retrieve normalization statistics (mean, std).

    Returns:
        tuple: Normalization statistics (mean, std).
    """
    if self._norm_stats is None:
        self._norm_stats = self.train_mean_std()
        print("Norm stats", self._norm_stats)
    return self._norm_stats

post_fn()

Create a post-processing function for normalizing data.

Returns:

- callable: Post-processing function.

Source code in ocean4dvarnet/data.py
def post_fn(self):
    """
    Create a post-processing function for normalizing data.

    Returns:
        callable: Post-processing function.
    """
    m, s = self.norm_stats()
    def normalize(item): return (item - m) / s
    return ft.partial(ft.reduce, lambda i, f: f(i), [
        TrainingItem._make,
        lambda item: item._replace(tgt=normalize(item.tgt)),
        lambda item: item._replace(input=normalize(item.input)),
    ])
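
For orientation, a minimal sketch of what the returned callable does, assuming TrainingItem is the two-field (input, tgt) namedtuple of this module. Supplying norm_stats at construction keeps the example free of real data; all values are hypothetical.

import numpy as np

from ocean4dvarnet.data import BaseDataModule

# Placeholder data arguments: only norm_stats is needed to obtain post_fn().
dm = BaseDataModule(input_da=None, domains={}, xrds_kw={}, dl_kw={}, norm_stats=(0.5, 2.0))
post_fn = dm.post_fn()

raw = np.random.rand(2, 5, 32, 32).astype('float32')  # a raw patch: 'input' slice then 'tgt' slice
item = post_fn(raw)
# item is a TrainingItem; assuming the field order (input, tgt), item.input and item.tgt
# are raw[0] and raw[1] normalized as (x - 0.5) / 2.0.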

setup(stage='test')

Set up the datasets for training, validation, and testing.

Parameters:

- stage (str, default 'test'): Stage of the setup ('train', 'val', 'test').

Source code in ocean4dvarnet/data.py
def setup(self, stage='test'):
    """
    Set up the datasets for training, validation, and testing.

    Args:
        stage (str, optional): Stage of the setup ('train', 'val', 'test').
    """
    train_data = self.input_da.sel(self.domains['train'])
    post_fn = self.post_fn()
    self.train_ds = XrDataset(
        train_data, **self.xrds_kw, postpro_fn=post_fn,
    )
    if self.aug_kw:
        self.train_ds = AugmentedDataset(self.train_ds, **self.aug_kw)

    self.val_ds = XrDataset(
        self.input_da.sel(self.domains['val']), **self.xrds_kw, postpro_fn=post_fn,
    )
    self.test_ds = XrDataset(
        self.input_da.sel(self.domains['test']), **self.xrds_kw, postpro_fn=post_fn,
    )

test_dataloader()

Create a DataLoader for the testing dataset.

Returns:

- DataLoader: Testing DataLoader.

Source code in ocean4dvarnet/data.py
def test_dataloader(self):
    """
    Create a DataLoader for the testing dataset.

    Returns:
        DataLoader: Testing DataLoader.
    """
    return torch.utils.data.DataLoader(self.test_ds, shuffle=False, **self.dl_kw)

train_dataloader()

Create a DataLoader for the training dataset.

Returns:

- DataLoader: Training DataLoader.

Source code in ocean4dvarnet/data.py
def train_dataloader(self):
    """
    Create a DataLoader for the training dataset.

    Returns:
        DataLoader: Training DataLoader.
    """
    return torch.utils.data.DataLoader(self.train_ds, shuffle=True, **self.dl_kw)

train_mean_std(variable='tgt')

Compute the mean and standard deviation of the training data.

Parameters:

- variable (str, default 'tgt'): Variable to compute statistics for.

Returns:

- tuple: Mean and standard deviation.

Source code in ocean4dvarnet/data.py
def train_mean_std(self, variable='tgt'):
    """
    Compute the mean and standard deviation of the training data.

    Args:
        variable (str, optional): Variable to compute statistics for.

    Returns:
        tuple: Mean and standard deviation.
    """
    train_data = self.input_da.sel(self.xrds_kw.get('domain_limits', {})).sel(self.domains['train'])
    return train_data.sel(variable=variable).pipe(lambda da: (da.mean().values.item(), da.std().values.item()))

val_dataloader()

Create a DataLoader for the validation dataset.

Returns:

- DataLoader: Validation DataLoader.

Source code in ocean4dvarnet/data.py
def val_dataloader(self):
    """
    Create a DataLoader for the validation dataset.

    Returns:
        DataLoader: Validation DataLoader.
    """
    return torch.utils.data.DataLoader(self.val_ds, shuffle=False, **self.dl_kw)

ConcatDataModule

Bases: BaseDataModule

A data module for concatenating datasets from multiple domains.

Source code in ocean4dvarnet/data.py
class ConcatDataModule(BaseDataModule):
    """A data module for concatenating datasets from multiple domains."""

    def train_mean_std(self):
        """
        Compute the mean and standard deviation of the training data across domains.

        Returns:
            tuple: Mean and standard deviation.
        """
        sum, count = 0, 0
        train_data = self.input_da.sel(self.xrds_kw.get('domain_limits', {}))
        for domain in self.domains['train']:
            _sum, _count = train_data.sel(domain).sel(variable='tgt').pipe(
                lambda da: (da.sum(), da.pipe(np.isfinite).sum())
            )
            sum += _sum
            count += _count

        mean = sum / count
        sum = 0
        for domain in self.domains['train']:
            _sum = train_data.sel(domain).sel(variable='tgt').pipe(lambda da: da - mean).pipe(np.square).sum()
            sum += _sum
        std = (sum / count)**0.5
        return mean.values.item(), std.values.item()

    def setup(self, stage='test'):
        """
        Set up the datasets for training, validation, and testing.

        Args:
            stage (str, optional): Stage of the setup ('train', 'val', 'test').
        """
        post_fn = self.post_fn()
        self.train_ds = XrConcatDataset([
            XrDataset(self.input_da.sel(domain), **self.xrds_kw, postpro_fn=post_fn,)
            for domain in self.domains['train']
        ])
        if self.aug_factor >= 1:
            self.train_ds = AugmentedDataset(self.train_ds, **self.aug_kw)

        self.val_ds = XrConcatDataset([
            XrDataset(self.input_da.sel(domain), **self.xrds_kw, postpro_fn=post_fn,)
            for domain in self.domains['val']
        ])
        self.test_ds = XrConcatDataset([
            XrDataset(self.input_da.sel(domain), **self.xrds_kw, postpro_fn=post_fn,)
            for domain in self.domains['test']
        ])
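
The difference from BaseDataModule is that each split holds a list of domain selections rather than a single one. A hedged construction sketch (dimension names, coordinates and slices are illustrative):

import numpy as np
import xarray as xr

from ocean4dvarnet.data import ConcatDataModule

# Hypothetical input stacked along a 'variable' dimension ('input', 'tgt').
da = xr.DataArray(
    np.random.rand(2, 30, 32, 96).astype('float32'),
    dims=('variable', 'time', 'lat', 'lon'),
    coords={'variable': ['input', 'tgt'], 'time': np.arange(30),
            'lat': np.arange(32), 'lon': np.arange(96)},
)

dm = ConcatDataModule(
    input_da=da,
    # Each split is a list of domain selections: setup() builds one XrDataset per entry
    # and concatenates them with XrConcatDataset, while train_mean_std() accumulates
    # sums and counts over the whole train list.
    domains={'train': [{'lon': slice(0, 31)}, {'lon': slice(32, 63)}],
             'val': [{'lon': slice(64, 79)}],
             'test': [{'lon': slice(80, 95)}]},
    xrds_kw=dict(patch_dims={'time': 5, 'lat': 32, 'lon': 16}),
    dl_kw=dict(batch_size=4),
)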

setup(stage='test')

Set up the datasets for training, validation, and testing.

Parameters:

- stage (str, default 'test'): Stage of the setup ('train', 'val', 'test').

Source code in ocean4dvarnet/data.py
def setup(self, stage='test'):
    """
    Set up the datasets for training, validation, and testing.

    Args:
        stage (str, optional): Stage of the setup ('train', 'val', 'test').
    """
    post_fn = self.post_fn()
    self.train_ds = XrConcatDataset([
        XrDataset(self.input_da.sel(domain), **self.xrds_kw, postpro_fn=post_fn,)
        for domain in self.domains['train']
    ])
    if self.aug_factor >= 1:
        self.train_ds = AugmentedDataset(self.train_ds, **self.aug_kw)

    self.val_ds = XrConcatDataset([
        XrDataset(self.input_da.sel(domain), **self.xrds_kw, postpro_fn=post_fn,)
        for domain in self.domains['val']
    ])
    self.test_ds = XrConcatDataset([
        XrDataset(self.input_da.sel(domain), **self.xrds_kw, postpro_fn=post_fn,)
        for domain in self.domains['test']
    ])

train_mean_std()

Compute the mean and standard deviation of the training data across domains.

Returns:

- tuple: Mean and standard deviation.

Source code in ocean4dvarnet/data.py
def train_mean_std(self):
    """
    Compute the mean and standard deviation of the training data across domains.

    Returns:
        tuple: Mean and standard deviation.
    """
    sum, count = 0, 0
    train_data = self.input_da.sel(self.xrds_kw.get('domain_limits', {}))
    for domain in self.domains['train']:
        _sum, _count = train_data.sel(domain).sel(variable='tgt').pipe(
            lambda da: (da.sum(), da.pipe(np.isfinite).sum())
        )
        sum += _sum
        count += _count

    mean = sum / count
    sum = 0
    for domain in self.domains['train']:
        _sum = train_data.sel(domain).sel(variable='tgt').pipe(lambda da: da - mean).pipe(np.square).sum()
        sum += _sum
    std = (sum / count)**0.5
    return mean.values.item(), std.values.item()

DangerousDimOrdering

Bases: Exception

Exception raised when the dimension ordering of the input data is incorrect.

Source code in ocean4dvarnet/data.py
class DangerousDimOrdering(Exception):
    """Exception raised when the dimension ordering of the input data is incorrect."""
    pass

IncompleteScanConfiguration

Bases: Exception

Exception raised when the scan configuration does not cover the entire domain.

Source code in ocean4dvarnet/data.py
class IncompleteScanConfiguration(Exception):
    """Exception raised when the scan configuration does not cover the entire domain."""
    pass
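
As a concrete illustration of the full-scan check (all sizes are arbitrary): with 100 points along a dimension, a patch of 20 and a stride of 15, (100 - 20) % 15 == 5, so five trailing points would never be visited and the check fails, whereas a stride of 16 tiles the dimension exactly.

import numpy as np
import xarray as xr

from ocean4dvarnet.data import IncompleteScanConfiguration, XrDataset

da = xr.DataArray(np.zeros(100, dtype='float32'), dims=('time',), coords={'time': np.arange(100)})

try:
    XrDataset(da, patch_dims={'time': 20}, strides={'time': 15}, check_full_scan=True)
except IncompleteScanConfiguration:
    pass  # (100 - 20) % 15 != 0

XrDataset(da, patch_dims={'time': 20}, strides={'time': 16}, check_full_scan=True)  # passes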

RandValDataModule

Bases: BaseDataModule

A data module that randomly splits the training data into training and validation sets.

Attributes:

- val_prop (float): Proportion of data to use for validation.

Source code in ocean4dvarnet/data.py
class RandValDataModule(BaseDataModule):
    """
    A data module that randomly splits the training data into training and validation sets.

    Attributes:
        val_prop (float): Proportion of data to use for validation.
    """

    def __init__(self, val_prop, *args, **kwargs):
        """
        Initialize the RandValDataModule.

        Args:
            val_prop (float): Proportion of data to use for validation.
        """
        super().__init__(*args, **kwargs)
        self.val_prop = val_prop

    def setup(self, stage='test'):
        """
        Set up the datasets for training, validation, and testing.

        Args:
            stage (str, optional): Stage of the setup ('train', 'val', 'test').
        """
        post_fn = self.post_fn()
        train_ds = XrDataset(self.input_da.sel(self.domains['train']), **self.xrds_kw, postpro_fn=post_fn,)
        n_val = int(self.val_prop * len(train_ds))
        n_train = len(train_ds) - n_val
        self.train_ds, self.val_ds = torch.utils.data.random_split(train_ds, [n_train, n_val])

        if self.aug_factor > 1:
            self.train_ds = AugmentedDataset(self.train_ds, **self.aug_kw)

        self.test_ds = XrDataset(self.input_da.sel(self.domains['test']), **self.xrds_kw, postpro_fn=post_fn,)
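
A brief construction sketch (the input DataArray and the slices are illustrative). Note that val_prop is the first positional argument and that domains needs no 'val' entry, since validation patches are drawn from the train split:

import numpy as np
import xarray as xr

from ocean4dvarnet.data import RandValDataModule

da = xr.DataArray(
    np.random.rand(2, 40, 32, 32).astype('float32'),
    dims=('variable', 'time', 'lat', 'lon'),
    coords={'variable': ['input', 'tgt'], 'time': np.arange(40),
            'lat': np.arange(32), 'lon': np.arange(32)},
)

dm = RandValDataModule(
    0.2,  # val_prop: hold out 20% of the train patches for validation
    input_da=da,
    domains={'train': {'time': slice(0, 29)}, 'test': {'time': slice(30, 39)}},
    xrds_kw=dict(patch_dims={'time': 5, 'lat': 32, 'lon': 32}),
    dl_kw=dict(batch_size=4),
)
# In setup(), the train-domain XrDataset is split with torch.utils.data.random_split into
# n_train and n_val = int(0.2 * len(train_ds)) patches.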

__init__(val_prop, *args, **kwargs)

Initialize the RandValDataModule.

Parameters:

- val_prop (float, required): Proportion of data to use for validation.

Source code in ocean4dvarnet/data.py
def __init__(self, val_prop, *args, **kwargs):
    """
    Initialize the RandValDataModule.

    Args:
        val_prop (float): Proportion of data to use for validation.
    """
    super().__init__(*args, **kwargs)
    self.val_prop = val_prop

setup(stage='test')

Set up the datasets for training, validation, and testing.

Parameters:

- stage (str, default 'test'): Stage of the setup ('train', 'val', 'test').

Source code in ocean4dvarnet/data.py
def setup(self, stage='test'):
    """
    Set up the datasets for training, validation, and testing.

    Args:
        stage (str, optional): Stage of the setup ('train', 'val', 'test').
    """
    post_fn = self.post_fn()
    train_ds = XrDataset(self.input_da.sel(self.domains['train']), **self.xrds_kw, postpro_fn=post_fn,)
    n_val = int(self.val_prop * len(train_ds))
    n_train = len(train_ds) - n_val
    self.train_ds, self.val_ds = torch.utils.data.random_split(train_ds, [n_train, n_val])

    if self.aug_factor > 1:
        self.train_ds = AugmentedDataset(self.train_ds, **self.aug_kw)

    self.test_ds = XrDataset(self.input_da.sel(self.domains['test']), **self.xrds_kw, postpro_fn=post_fn,)

XrConcatDataset

Bases: ConcatDataset

A concatenation of multiple XrDatasets.

This class allows combining multiple datasets into one for training or evaluation.

Source code in ocean4dvarnet/data.py
class XrConcatDataset(torch.utils.data.ConcatDataset):
    """
    A concatenation of multiple XrDatasets.

    This class allows combining multiple datasets into one for training or evaluation.
    """

    def reconstruct(self, batches, weight=None):
        """
        Reconstruct the original data arrays from batches.

        Args:
            batches (list): List of batches.
            weight (np.ndarray, optional): Weighting for overlapping patches.

        Returns:
            list: List of reconstructed xarray.DataArray objects.
        """
        items_iter = itertools.chain(*batches)
        rec_das = []
        for ds in self.datasets:
            ds_items = list(itertools.islice(items_iter, len(ds)))
            rec_das.append(ds.reconstruct_from_items(ds_items, weight))

        return rec_das
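
A short sketch of concatenating two patch datasets built over different DataArrays (all names and sizes are illustrative):

import numpy as np
import xarray as xr

from ocean4dvarnet.data import XrConcatDataset, XrDataset

def make_da(n_time):
    return xr.DataArray(
        np.random.rand(n_time, 24, 24).astype('float32'),
        dims=('time', 'lat', 'lon'),
        coords={'time': np.arange(n_time), 'lat': np.arange(24), 'lon': np.arange(24)},
    )

kw = dict(patch_dims={'time': 5, 'lat': 24, 'lon': 24}, strides={'time': 5})
ds = XrConcatDataset([XrDataset(make_da(20), **kw), XrDataset(make_da(30), **kw)])
len(ds)  # 4 + 6 patches, indexed as a single dataset
# reconstruct(batches) slices the flattened stream of items back into per-dataset chunks
# and returns one reconstructed DataArray per underlying XrDataset.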

reconstruct(batches, weight=None)

Reconstruct the original data arrays from batches.

Parameters:

- batches (list, required): List of batches.
- weight (ndarray, default None): Weighting for overlapping patches.

Returns:

- list: List of reconstructed xarray.DataArray objects.

Source code in ocean4dvarnet/data.py
def reconstruct(self, batches, weight=None):
    """
    Reconstruct the original data arrays from batches.

    Args:
        batches (list): List of batches.
        weight (np.ndarray, optional): Weighting for overlapping patches.

    Returns:
        list: List of reconstructed xarray.DataArray objects.
    """
    items_iter = itertools.chain(*batches)
    rec_das = []
    for ds in self.datasets:
        ds_items = list(itertools.islice(items_iter, len(ds)))
        rec_das.append(ds.reconstruct_from_items(ds_items, weight))

    return rec_das

XrDataset

Bases: Dataset

A PyTorch Dataset based on an xarray.DataArray with on-the-fly slicing.

This class allows efficient extraction of patches from an xarray.DataArray for training machine learning models.

Usage

If you want to be able to reconstruct the input, the input xr.DataArray should:
- Have coordinates.
- Have its last dims correspond to the patch dims, in the same order.
- Have, for each dim of patch_dims, (size(dim) - patch_dims(dim)) divisible by stride(dim).

The batches passed to self.reconstruct should:
- Have their last dims correspond to the patch dims, in the same order.

Attributes:

- da (DataArray): The input data array.
- patch_dims (dict): Dimensions and sizes of patches to extract.
- domain_limits (dict): Limits for selecting a subset of the domain.
- strides (dict): Strides for patch extraction.
- check_full_scan (bool): Whether to check if the entire domain is scanned.
- check_dim_order (bool): Whether to check the dimension ordering.
- postpro_fn (callable): A function for post-processing extracted patches.

Source code in ocean4dvarnet/data.py
class XrDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset based on an xarray.DataArray with on-the-fly slicing.

    This class allows efficient extraction of patches from an xarray.DataArray
    for training machine learning models.

    Usage:
        If you want to be able to reconstruct the input, the input xr.DataArray should:
        - Have coordinates.
        - Have the last dims correspond to the patch dims in the same order.
        - Have, for each dim of patch_dim, (size(dim) - patch_dim(dim)) divisible by stride(dim).

        The batches passed to self.reconstruct should:
        - Have the last dims correspond to the patch dims in the same order.


    Attributes:
        da (xarray.DataArray): The input data array.
        patch_dims (dict): Dimensions and sizes of patches to extract.
        domain_limits (dict): Limits for selecting a subset of the domain.
        strides (dict): Strides for patch extraction.
        check_full_scan (bool): Whether to check if the entire domain is scanned.
        check_dim_order (bool): Whether to check the dimension ordering.
        postpro_fn (callable): A function for post-processing extracted patches.

    """

    def __init__(
            self, da, patch_dims, domain_limits=None, strides=None,
            check_full_scan=False, check_dim_order=False,
            postpro_fn=None
    ):
        """
        Initialize the XrDataset.

        Args:
            da (xarray.DataArray): Input data, with patch dims at the end in the dim orders
            patch_dims (dict):  da dimension and sizes of patches to extract.
            domain_limits (dict, optional): da dimension slices of domain, to Limits for selecting
                                            a subset of the domain. for patch extractions
            strides (dict, optional): dims to strides size for patch extraction.(default to one)
            check_full_scan (bool, optional): if True raise an error if the whole domain is not scanned by the patch.
            check_dim_order (bool, optional): Whether to check the dimension ordering.
            postpro_fn (callable, optional): A function for post-processing extracted patches.
        """
        super().__init__()
        self.return_coords = False
        self.postpro_fn = postpro_fn
        self.da = da.sel(**(domain_limits or {}))
        self.patch_dims = patch_dims
        self.strides = strides or {}
        da_dims = dict(zip(self.da.dims, self.da.shape))
        self.ds_size = {
            dim: max((da_dims[dim] - patch_dims[dim]) // self.strides.get(dim, 1) + 1, 0)
            for dim in patch_dims
        }

        if check_full_scan:
            for dim in patch_dims:
                if (da_dims[dim] - self.patch_dims[dim]) % self.strides.get(dim, 1) != 0:
                    raise IncompleteScanConfiguration(
                        f"""
                        Incomplete scan in dimension dim {dim}:
                        dataarray shape on this dim {da_dims[dim]}
                        patch_size along this dim {self.patch_dims[dim]}
                        stride along this dim {self.strides.get(dim, 1)}
                        [shape - patch_size] should be divisible by stride
                        """
                    )

        if check_dim_order:
            for dim in patch_dims:
                if not '#'.join(da.dims).endswith('#'.join(list(patch_dims))):
                    raise DangerousDimOrdering(
                        f"""
                        input dataarray's dims should end with patch_dims
                        dataarray's dim {da.dims}:
                        patch_dims {list(patch_dims)}
                        """
                    )

    def __len__(self):
        """
        Return the total number of patches in the dataset.

        Returns:
            int: Number of patches.
        """
        size = 1
        for v in self.ds_size.values():
            size *= v
        return size

    def __iter__(self):
        """
        Iterate over the dataset.

        Yields:
            Patch data for each index.
        """
        for i in range(len(self)):
            yield self[i]

    def get_coords(self):
        """
        Get the coordinates of all patches in the dataset.

        Returns:
            list: List of coordinates for each patch.
        """
        self.return_coords = True
        coords = []
        try:
            for i in range(len(self)):
                 coords.append(self[i])
        finally:
            self.return_coords = False
            return coords

    def __getitem__(self, item):
        """
        Get a specific patch by index.

        Args:
            item (int): Index of the patch.

        Returns:
            Patch data or coordinates, depending on the mode.
        """
        sl = {
            dim: slice(self.strides.get(dim, 1) * idx,
                       self.strides.get(dim, 1) * idx + self.patch_dims[dim])
            for dim, idx in zip(self.ds_size.keys(),
                                np.unravel_index(item, tuple(self.ds_size.values())))
        }
        item = self.da.isel(**sl)

        if self.return_coords:
            return item.coords.to_dataset()[list(self.patch_dims)]

        item = item.data.astype(np.float32)
        if self.postpro_fn is not None:
            return self.postpro_fn(item)
        return item

    def reconstruct(self, batches, weight=None):
        """
        Reconstruct the original data array from patches.

        Takes as input a list of np.ndarray of dimensions (b, *, *patch_dims).

        Args:
            batches (list): List of patches (torch tensor) corresponding to batches without shuffle.
            weight (np.ndarray, optional): Tensor of size patch_dims corresponding to the weight of a prediction 
                depending on the position on the patch (default to ones everywhere). Overlapping patches will
                be averaged with weighting.

        Returns:
            xarray.DataArray: Reconstructed data array. A stitched xarray.DataArray with the coords of patch_dims.
        """
        items = list(itertools.chain(*batches))
        return self.reconstruct_from_items(items, weight)

    def reconstruct_from_items(self, items, weight=None):
        """
        Reconstruct the original data array from individual items.

        Args:
            items (list): List of individual patches.
            weight (np.ndarray, optional): Weighting for overlapping patches.

        Returns:
            xarray.DataArray: Reconstructed data array.
        """
        if weight is None:
            weight = np.ones(list(self.patch_dims.values()))
        w = xr.DataArray(weight, dims=list(self.patch_dims.keys()))

        coords = self.get_coords()

        new_dims = [f'v{i}' for i in range(len(items[0].shape) - len(coords[0].dims))]
        dims = new_dims + list(coords[0].dims)

        das = [xr.DataArray(it.numpy(), dims=dims, coords=co.coords)
               for it, co in zip(items, coords)]

        da_shape = dict(zip(coords[0].dims, self.da.shape[-len(coords[0].dims):]))
        new_shape = dict(zip(new_dims, items[0].shape[:len(new_dims)]))

        rec_da = xr.DataArray(
            np.zeros([*new_shape.values(), *da_shape.values()]),
            dims=dims,
            coords={d: self.da[d] for d in self.patch_dims}
        )
        count_da = xr.zeros_like(rec_da)

        for da in das:
            rec_da.loc[da.coords] = rec_da.sel(da.coords) + da * w
            count_da.loc[da.coords] = count_da.sel(da.coords) + w

        return rec_da / count_da
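
A minimal standalone sketch (dimension names, sizes and strides are illustrative):

import numpy as np
import xarray as xr

from ocean4dvarnet.data import XrDataset

da = xr.DataArray(
    np.random.rand(30, 64, 64).astype('float32'),
    dims=('time', 'lat', 'lon'),
    coords={'time': np.arange(30), 'lat': np.arange(64), 'lon': np.arange(64)},
)

ds = XrDataset(
    da,
    patch_dims={'time': 5, 'lat': 32, 'lon': 32},
    strides={'time': 5, 'lat': 16, 'lon': 16},  # lat/lon patches overlap by half
    check_full_scan=True,
)
len(ds)      # 6 * 3 * 3 = 54 patches
ds[0].shape  # (5, 32, 32) float32 numpy array, since no postpro_fn is given
coords = ds.get_coords()  # per-patch coordinates, as used by reconstruct_from_items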

__getitem__(item)

Get a specific patch by index.

Parameters:

- item (int, required): Index of the patch.

Returns:

- Patch data or coordinates, depending on the mode.

Source code in ocean4dvarnet/data.py
def __getitem__(self, item):
    """
    Get a specific patch by index.

    Args:
        item (int): Index of the patch.

    Returns:
        Patch data or coordinates, depending on the mode.
    """
    sl = {
        dim: slice(self.strides.get(dim, 1) * idx,
                   self.strides.get(dim, 1) * idx + self.patch_dims[dim])
        for dim, idx in zip(self.ds_size.keys(),
                            np.unravel_index(item, tuple(self.ds_size.values())))
    }
    item = self.da.isel(**sl)

    if self.return_coords:
        return item.coords.to_dataset()[list(self.patch_dims)]

    item = item.data.astype(np.float32)
    if self.postpro_fn is not None:
        return self.postpro_fn(item)
    return item

__init__(da, patch_dims, domain_limits=None, strides=None, check_full_scan=False, check_dim_order=False, postpro_fn=None)

Initialize the XrDataset.

Parameters:

- da (DataArray, required): Input data array, with the patch dims last, in the same order as patch_dims.
- patch_dims (dict, required): Mapping of da dimension names to the patch size along each dimension.
- domain_limits (dict, default None): Mapping of da dimension names to slices used to select a subset of the domain before patch extraction.
- strides (dict, default None): Mapping of dimension names to stride sizes for patch extraction (defaults to 1 along each dimension).
- check_full_scan (bool, default False): If True, raise an error when the patches do not scan the whole domain.
- check_dim_order (bool, default False): Whether to check the dimension ordering.
- postpro_fn (callable, default None): A function for post-processing extracted patches.

Source code in ocean4dvarnet/data.py
def __init__(
        self, da, patch_dims, domain_limits=None, strides=None,
        check_full_scan=False, check_dim_order=False,
        postpro_fn=None
):
    """
    Initialize the XrDataset.

    Args:
        da (xarray.DataArray): Input data, with patch dims at the end in the dim orders
        patch_dims (dict):  da dimension and sizes of patches to extract.
        domain_limits (dict, optional): da dimension slices of domain, to Limits for selecting
                                        a subset of the domain. for patch extractions
        strides (dict, optional): dims to strides size for patch extraction.(default to one)
        check_full_scan (bool, optional): if True raise an error if the whole domain is not scanned by the patch.
        check_dim_order (bool, optional): Whether to check the dimension ordering.
        postpro_fn (callable, optional): A function for post-processing extracted patches.
    """
    super().__init__()
    self.return_coords = False
    self.postpro_fn = postpro_fn
    self.da = da.sel(**(domain_limits or {}))
    self.patch_dims = patch_dims
    self.strides = strides or {}
    da_dims = dict(zip(self.da.dims, self.da.shape))
    self.ds_size = {
        dim: max((da_dims[dim] - patch_dims[dim]) // self.strides.get(dim, 1) + 1, 0)
        for dim in patch_dims
    }

    if check_full_scan:
        for dim in patch_dims:
            if (da_dims[dim] - self.patch_dims[dim]) % self.strides.get(dim, 1) != 0:
                raise IncompleteScanConfiguration(
                    f"""
                    Incomplete scan in dimension dim {dim}:
                    dataarray shape on this dim {da_dims[dim]}
                    patch_size along this dim {self.patch_dims[dim]}
                    stride along this dim {self.strides.get(dim, 1)}
                    [shape - patch_size] should be divisible by stride
                    """
                )

    if check_dim_order:
        for dim in patch_dims:
            if not '#'.join(da.dims).endswith('#'.join(list(patch_dims))):
                raise DangerousDimOrdering(
                    f"""
                    input dataarray's dims should end with patch_dims
                    dataarray's dim {da.dims}:
                    patch_dims {list(patch_dims)}
                    """
                )

__iter__()

Iterate over the dataset.

Yields:

- Patch data for each index.

Source code in ocean4dvarnet/data.py
def __iter__(self):
    """
    Iterate over the dataset.

    Yields:
        Patch data for each index.
    """
    for i in range(len(self)):
        yield self[i]

__len__()

Return the total number of patches in the dataset.

Returns:

- int: Number of patches.

Source code in ocean4dvarnet/data.py
def __len__(self):
    """
    Return the total number of patches in the dataset.

    Returns:
        int: Number of patches.
    """
    size = 1
    for v in self.ds_size.values():
        size *= v
    return size

get_coords()

Get the coordinates of all patches in the dataset.

Returns:

- list: List of coordinates for each patch.

Source code in ocean4dvarnet/data.py
def get_coords(self):
    """
    Get the coordinates of all patches in the dataset.

    Returns:
        list: List of coordinates for each patch.
    """
    self.return_coords = True
    coords = []
    try:
        for i in range(len(self)):
             coords.append(self[i])
    finally:
        self.return_coords = False
        return coords

reconstruct(batches, weight=None)

Reconstruct the original data array from patches.

Takes as input a list of np.ndarray of dimensions (b, *, *patch_dims).

Parameters:

- batches (list, required): List of patches (torch tensors) corresponding to batches without shuffling.
- weight (ndarray, default None): Tensor of size patch_dims giving the weight of a prediction depending on its position in the patch (defaults to ones everywhere). Overlapping patches are averaged with this weighting.

Returns:

- xarray.DataArray: Reconstructed data array, stitched together with the coords of patch_dims.

Source code in ocean4dvarnet/data.py
def reconstruct(self, batches, weight=None):
    """
    Reconstruct the original data array from patches.

    Takes as input a list of np.ndarray of dimensions (b, *, *patch_dims).

    Args:
        batches (list): List of patches (torch tensor) corresponding to batches without shuffle.
        weight (np.ndarray, optional): Tensor of size patch_dims corresponding to the weight of a prediction 
            depending on the position on the patch (default to ones everywhere). Overlapping patches will
            be averaged with weighting.

    Returns:
        xarray.DataArray: Reconstructed data array. A stitched xarray.DataArray with the coords of patch_dims.
    """
    items = list(itertools.chain(*batches))
    return self.reconstruct_from_items(items, weight)

reconstruct_from_items(items, weight=None)

Reconstruct the original data array from individual items.

Parameters:

- items (list, required): List of individual patches.
- weight (ndarray, default None): Weighting for overlapping patches.

Returns:

- xarray.DataArray: Reconstructed data array.

Source code in ocean4dvarnet/data.py
def reconstruct_from_items(self, items, weight=None):
    """
    Reconstruct the original data array from individual items.

    Args:
        items (list): List of individual patches.
        weight (np.ndarray, optional): Weighting for overlapping patches.

    Returns:
        xarray.DataArray: Reconstructed data array.
    """
    if weight is None:
        weight = np.ones(list(self.patch_dims.values()))
    w = xr.DataArray(weight, dims=list(self.patch_dims.keys()))

    coords = self.get_coords()

    new_dims = [f'v{i}' for i in range(len(items[0].shape) - len(coords[0].dims))]
    dims = new_dims + list(coords[0].dims)

    das = [xr.DataArray(it.numpy(), dims=dims, coords=co.coords)
           for it, co in zip(items, coords)]

    da_shape = dict(zip(coords[0].dims, self.da.shape[-len(coords[0].dims):]))
    new_shape = dict(zip(new_dims, items[0].shape[:len(new_dims)]))

    rec_da = xr.DataArray(
        np.zeros([*new_shape.values(), *da_shape.values()]),
        dims=dims,
        coords={d: self.da[d] for d in self.patch_dims}
    )
    count_da = xr.zeros_like(rec_da)

    for da in das:
        rec_da.loc[da.coords] = rec_da.sel(da.coords) + da * w
        count_da.loc[da.coords] = count_da.sel(da.coords) + w

    return rec_da / count_da
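
To close the loop, a hedged round-trip sketch: patches come from an unshuffled DataLoader, each batch stands in for a model prediction (here the identity), and reconstruct stitches them back onto the original grid. Names and sizes are illustrative.

import numpy as np
import torch
import xarray as xr

from ocean4dvarnet.data import XrDataset

da = xr.DataArray(
    np.random.rand(20, 32, 32).astype('float32'),
    dims=('time', 'lat', 'lon'),
    coords={'time': np.arange(20), 'lat': np.arange(32), 'lon': np.arange(32)},
)
ds = XrDataset(da, patch_dims={'time': 5, 'lat': 32, 'lon': 32},
               strides={'time': 5}, check_full_scan=True)

loader = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False)
batches = list(loader)         # each batch is a (2, 5, 32, 32) tensor, in dataset order

rec = ds.reconstruct(batches)  # DataArray on the original (time, lat, lon) grid
# Since the "model" is the identity and the patches do not overlap, rec equals da;
# with overlapping patches, pass a weight array of shape patch_dims to taper the averaging.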