Skip to content

lorenz63.models

RearrangedBilinAEPriorCost

Bases: BilinAEPriorCost

Wrapper around the base prior cost that reshapes the input batch. It is used to convert the Lorenz time series into an "image" so that conv2d layers can be reused.

Source code in contrib/lorenz63/models.py (lines 11–25)
class RearrangedBilinAEPriorCost(src.models.BilinAEPriorCost):
    """Prior cost that reshapes its input around the base autoencoder pass.

    Before delegating to the parent's ``forward_ae``, the batch is
    rearranged from the ``rearrange_from`` layout to the ``rearrange_to``
    layout (einops notation). Afterwards it is rearranged back with the
    inverse pattern. With the default patterns, a batch laid out as
    ``b c t`` is handed to the parent as ``b t c 1``. This lets the Lorenz
    time series be treated as an "image" so conv2d layers can be reused.

    Args:
        rearrange_from: einops axis pattern of the batch the caller passes in.
        rearrange_to: einops axis pattern the parent model operates on.
        *args, **kwargs: forwarded unchanged to the parent constructor.

    NOTE: ``*args`` comes after the two defaulted pattern parameters, so
    positional arguments fill ``rearrange_from``/``rearrange_to`` first.
    Pass arguments meant for the parent by keyword.
    """

    def __init__(self, rearrange_from='b c t', rearrange_to='b t c ()', *args, **kwargs):
        super().__init__(*args, **kwargs)
        arrow = ' -> '
        # Forward pattern (caller layout -> parent layout) and its exact inverse.
        self.rearrange_bef = arrow.join((rearrange_from, rearrange_to))
        self.rearrange_aft = arrow.join((rearrange_to, rearrange_from))

    def forward_ae(self, x):
        """Run the parent autoencoder on ``x`` in the rearranged layout.

        The parent's output is mapped back with the inverse pattern, so the
        result is returned in the caller's ``rearrange_from`` layout.
        """
        reshaped = einops.rearrange(x, self.rearrange_bef)
        decoded = super().forward_ae(reshaped)
        return einops.rearrange(decoded, self.rearrange_aft)

RearrangedConvLstmGradModel

Bases: ConvLstmGradModel

Wrapper around the base grad model that reshapes the input batch. It is used to convert the Lorenz time series into an "image" so that conv2d layers can be reused.

Source code in contrib/lorenz63/models.py (lines 46–64)
class RearrangedConvLstmGradModel(src.models.ConvLstmGradModel):
    """Gradient model that reshapes its input around the base ConvLSTM model.

    Both ``reset_state`` and ``forward`` receive batches in the
    ``rearrange_from`` layout (einops notation). Before delegating to the
    parent, they convert those batches to the ``rearrange_to`` layout.
    ``forward`` also converts the parent's output back with the inverse
    pattern. With the default patterns, a batch laid out as ``b c t`` is
    handed to the parent as ``b t c 1``. This lets the Lorenz time series be
    treated as an "image" so conv2d layers can be reused.

    Args:
        rearrange_from: einops axis pattern of the batch the caller passes in.
        rearrange_to: einops axis pattern the parent model operates on.
        *args, **kwargs: forwarded unchanged to the parent constructor.

    NOTE: ``*args`` comes after the two defaulted pattern parameters, so
    positional arguments fill ``rearrange_from``/``rearrange_to`` first.
    Pass arguments meant for the parent by keyword.
    """

    def __init__(self, rearrange_from='b c t', rearrange_to='b t c ()', *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Forward pattern (caller layout -> parent layout) and its exact inverse.
        self.rearrange_bef = rearrange_from + ' -> ' + rearrange_to
        self.rearrange_aft = rearrange_to + ' -> ' + rearrange_from

    def reset_state(self, inp):
        """Delegate ``reset_state`` to the parent with ``inp`` rearranged.

        ``inp`` is converted to the parent's layout first, so the parent sees
        the same layout it later receives in ``forward``. Whatever the parent
        returns is passed through so the wrapper stays transparent.
        (Previously the parent's return value was dropped.)
        """
        return super().reset_state(einops.rearrange(inp, self.rearrange_bef))

    def forward(self, x):
        """Run the parent model on ``x`` in the rearranged layout.

        The parent's output is mapped back with the inverse pattern, so the
        result is returned in the caller's ``rearrange_from`` layout.
        """
        x = einops.rearrange(x, self.rearrange_bef)
        x = super().forward(x)
        x = einops.rearrange(x, self.rearrange_aft)
        return x