Skip to content

Commit 0da5ae5

Browse files
authored
Merge pull request #794 from alan-turing-institute/788-reaction-diffusion-dataset
Add dataset for reaction diffusion example (#788)
2 parents ee81d5e + 29cdf1a commit 0da5ae5

9 files changed

Lines changed: 367 additions & 41 deletions

File tree

autoemulate/emulators/base.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from torch import nn, optim
99
from torch.distributions import TransformedDistribution
1010
from torch.optim.lr_scheduler import ExponentialLR, LRScheduler
11-
from torch.utils.data import DataLoader
1211

1312
from autoemulate.core.device import TorchDeviceMixin
1413
from autoemulate.core.types import (
@@ -40,9 +39,9 @@ class Emulator(ABC, ValidationMixin, ConversionMixin, TorchDeviceMixin):
4039
supports_uq: bool = False
4140

4241
@abstractmethod
43-
def _fit(self, x: TensorLike | DataLoader, y: TensorLike | DataLoader | None): ...
42+
def _fit(self, x: TensorLike, y: TensorLike): ...
4443

45-
def fit(self, x: TensorLike | DataLoader, y: TensorLike | DataLoader | None):
44+
def fit(self, x: TensorLike, y: TensorLike):
4645
"""Fit the emulator to the provided data."""
4746
if isinstance(x, TensorLike) and isinstance(y, TensorLike):
4847
self._check(x, y)
@@ -62,11 +61,6 @@ def fit(self, x: TensorLike | DataLoader, y: TensorLike | DataLoader | None):
6261

6362
# Fit emulator
6463
self._fit(x, y)
65-
elif isinstance(x, DataLoader) and y is None:
66-
self._fit(x, y)
67-
else:
68-
msg = "Invalid input types. Expected pair of TensorLike or DataLoader."
69-
raise RuntimeError(msg)
7064
self.is_fitted_ = True
7165

7266
@abstractmethod
@@ -547,7 +541,7 @@ def loss_func(self, y_pred, y_true):
547541
"""Loss function to be used for training the model."""
548542
return nn.MSELoss()(y_pred, y_true)
549543

550-
def _fit(self, x: TensorLike, y: TensorLike): # type: ignore since this is valid subclass of types
544+
def _fit(self, x: TensorLike, y: TensorLike):
551545
"""
552546
Train a PyTorchBackend model.
553547
@@ -671,7 +665,7 @@ class SklearnBackend(DeterministicEmulator):
671665
def _model_specific_check(self, x: NumpyLike, y: NumpyLike):
672666
_, _ = x, y
673667

674-
def _fit(self, x: TensorLike, y: TensorLike): # type: ignore since this is valid subclass of types
668+
def _fit(self, x: TensorLike, y: TensorLike):
675669
if self.normalize_y:
676670
y, y_mean, y_std = self._normalize(y)
677671
self.y_mean = y_mean

autoemulate/emulators/ensemble.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def get_tune_params() -> TuneParams:
6767
"""Return a dictionary of hyperparameters to tune."""
6868
return {}
6969

70-
def _fit(self, x: TensorLike, y: TensorLike) -> None: # type: ignore since this is valid subclass of types
70+
def _fit(self, x: TensorLike, y: TensorLike) -> None:
7171
for e in self.emulators:
7272
e.fit(x, y)
7373
self.is_fitted_ = True
@@ -248,7 +248,7 @@ def get_tune_params() -> TuneParams:
248248
"n_samples": [10, 20, 50, 100],
249249
}
250250

251-
def _fit(self, x: TensorLike, y: TensorLike) -> None: # type: ignore since this is valid subclass of types
251+
def _fit(self, x: TensorLike, y: TensorLike) -> None:
252252
# Delegate training to the wrapped model
253253
self.model.fit(x, y)
254254
self.is_fitted_ = True

autoemulate/emulators/gaussian_process/exact.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def forward(self, x: TensorLike):
183183
MultivariateNormal(mean, covar)
184184
)
185185

186-
def _fit(self, x: TensorLike, y: TensorLike): # type: ignore since this is valid subclass of types
186+
def _fit(self, x: TensorLike, y: TensorLike):
187187
self.train()
188188
self.likelihood.train()
189189

autoemulate/emulators/transformed/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def _inv_transform_y_distribution(self, y_t: DistributionLike) -> DistributionLi
374374
"""
375375
return TransformedDistribution(y_t, [ComposeTransform(self.y_transforms).inv])
376376

377-
def _fit(self, x: TensorLike, y: TensorLike): # type: ignore since this is valid subclass of types
377+
def _fit(self, x: TensorLike, y: TensorLike):
378378
# Transform x and y
379379
x_t = self._transform_x(x)
380380
y_t = self._transform_y_tensor(y)

autoemulate/experimental/data/spatio_temporal_dataset.py renamed to autoemulate/experimental/data/spatiotemporal_dataset.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ class AutoEmulateDataset(Dataset):
99

1010
def __init__(
1111
self,
12-
data_path: str,
13-
n_steps_input: int,
14-
n_steps_output: int,
12+
data_path: str | None,
13+
data: dict | None = None,
14+
n_steps_input: int = 1,
15+
n_steps_output: int = 1,
1516
stride: int = 1,
1617
# TODO: support for passing data from dict
17-
# data: dict | None = None,
1818
input_channel_idxs: tuple[int, ...] | None = None,
1919
output_channel_idxs: tuple[int, ...] | None = None,
2020
):
@@ -45,20 +45,8 @@ def __init__(
4545
self.input_channel_idxs = input_channel_idxs
4646
self.output_channel_idxs = output_channel_idxs
4747

48-
# TODO: support passing as dict
49-
# Load data
50-
with h5py.File(data_path, "r") as f:
51-
assert "data" in f, "HDF5 file must contain 'data' dataset"
52-
self.data: TensorLike = torch.Tensor(f["data"][:]) # type: ignore # [N, T, W, H, C] # noqa: PGH003
53-
print(f"Loaded data shape: {self.data.shape}")
54-
# TODO: add the constant scalars
55-
self.constant_scalars = (
56-
torch.Tensor(f["constant_scalars"][:]) # type: ignore # noqa: PGH003
57-
if "constant_scalars" in f
58-
else None
59-
) # [N, C]
60-
# TODO: add the constant fields
61-
# self.constant_fields = torch.Tensor(f['data'][:]) # [N, W, H, C]
48+
# Read or parse data
49+
self.read_data(data_path) if data_path is not None else self.parse_data(data)
6250

6351
# Destructured here
6452
(
@@ -107,14 +95,45 @@ def __init__(
10795
print(f"Each input sample shape: {self.all_input_fields[0].shape}")
10896
print(f"Each output sample shape: {self.all_output_fields[0].shape}")
10997

98+
def read_data(self, data_path: str):
    """Read data from an HDF5 file.

    The file must contain a ``data`` dataset. A ``constant_scalars`` dataset
    is loaded as well when the file provides one; otherwise the attribute is
    set to ``None``.

    Parameters
    ----------
    data_path: str
        Path of the HDF5 file to open (stored on ``self.data_path``).

    Raises
    ------
    AssertionError
        If the file has no ``data`` dataset.
    """
    self.data_path = data_path
    with h5py.File(self.data_path, "r") as f:
        assert "data" in f, "HDF5 file must contain 'data' dataset"
        self.data: TensorLike = torch.Tensor(f["data"][:])  # type: ignore # [N, T, W, H, C] # noqa: PGH003
        print(f"Loaded data shape: {self.data.shape}")
        # TODO: add the constant scalars
        self.constant_scalars = (
            torch.Tensor(f["constant_scalars"][:])  # type: ignore # noqa: PGH003
            if "constant_scalars" in f
            else None
        )  # [N, C]
        # TODO: add the constant fields
        # self.constant_fields = torch.Tensor(f['data'][:])  # [N, W, H, C]
    # Constant fields are not read from file yet. Define the attribute anyway
    # so that both loading paths (`read_data` and `parse_data`) leave the
    # instance with the same set of attributes.
    self.constant_fields = None
118+
119+
def parse_data(self, data: dict | None):
120+
"""Parse data from a dictionary."""
121+
if data is not None:
122+
self.data = data["data"]
123+
self.constant_scalars = data.get("constant_scalars", None)
124+
self.constant_fields = data.get("constant_fields", None)
125+
return
126+
msg = "No data provided to parse."
127+
raise ValueError(msg)
128+
110129
def __len__(self): # noqa: D105
111130
return len(self.all_input_fields)
112131

113132
def __getitem__(self, idx): # noqa: D105
114133
return {
115134
"input_fields": self.all_input_fields[idx],
116135
"output_fields": self.all_output_fields[idx],
117-
# "constant_scalars": self.all_constant_scalars[idx],
136+
"constant_scalars": self.all_constant_scalars[idx],
118137
# TODO: add this
119138
# "constant_fields": self.all_constant_fields[idx],
120139
}

autoemulate/experimental/emulators/fno.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
2-
from autoemulate.core.types import TensorLike
3-
from autoemulate.emulators.base import PyTorchBackend
2+
from autoemulate.core.types import OutputLike, TensorLike
3+
from autoemulate.experimental.emulators.spatiotemporal import SpatioTemporalEmulator
44
from neuralop.models import FNO
55
from torch.utils.data import DataLoader
66

@@ -41,14 +41,23 @@ def prepare_batch(sample, channels=(0,), with_constants=True, with_time=False):
4141
return x, y
4242

4343

44-
class FNOEmulator(PyTorchBackend):
44+
class FNOEmulator(SpatioTemporalEmulator):
4545
"""An FNO emulator."""
4646

47-
def __init__(self, x, y, *args, **kwargs):
48-
_, _ = x, y
47+
def __init__(self, x=None, y=None, *args, **kwargs):
    """Create the wrapped FNO model and its Adam optimizer.

    Parameters
    ----------
    x, y:
        Accepted but not used by this initialiser.
    *args:
        Accepted but not referenced in the body.
    **kwargs:
        Forwarded unchanged to the ``FNO`` constructor.
    """
    del x, y  # Unused
    # Parent initialisers have to run before any nn.Module attribute is set.
    super().__init__()
    self.model = FNO(**kwargs)
    self.optimizer = torch.optim.Adam(self.model.parameters())
5053

51-
def _fit(self, x: DataLoader, y: DataLoader | None): # type: ignore # noqa: PGH003
54+
@staticmethod
55+
def is_multioutput() -> bool: # noqa: D102
56+
return True
57+
58+
def _fit(self, x: TensorLike | DataLoader, y: TensorLike | None = None):
59+
assert isinstance(x, DataLoader), "x currently must be a DataLoader"
60+
assert y is None, "y currently must be None"
5261
channels = (0,) # Which channel to use
5362
for idx, batch in enumerate(x):
5463
# Prepare input with constants
@@ -73,5 +82,16 @@ def forward(self, x: TensorLike):
7382
"""Forward pass."""
7483
return self.model(x)
7584

76-
def _predict(self, x, with_grad):
77-
return super()._predict(x, with_grad)
85+
def _predict(self, x: TensorLike | DataLoader, with_grad: bool) -> OutputLike:
    """Run the model over every batch of a DataLoader and stack the outputs.

    Parameters
    ----------
    x: TensorLike | DataLoader
        Currently has to be a `DataLoader`; each item it yields is passed to
        `prepare_batch`.
    with_grad: bool
        Whether gradient tracking is enabled while predicting.

    Returns
    -------
    OutputLike
        The per-batch model outputs concatenated with `torch.cat`.
    """
    assert isinstance(x, DataLoader), "x currently must be a DataLoader"
    selected_channels = (0,)  # Which channel to use
    batch_outputs = []
    with torch.set_grad_enabled(with_grad):
        for sample in x:
            # Prepare input with constants; the target half is not needed here
            inputs, _ = prepare_batch(
                sample,
                channels=selected_channels,
                with_constants=True,
                with_time=True,
            )
            batch_outputs.append(self(inputs))
        return torch.cat(batch_outputs)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from abc import abstractmethod
2+
3+
from autoemulate.core.types import OutputLike, TensorLike
4+
from autoemulate.emulators.base import PyTorchBackend
5+
from torch.utils.data import DataLoader
6+
7+
8+
class SpatioTemporalEmulator(PyTorchBackend):
    """A spatio-temporal backend for emulators.

    Extends `PyTorchBackend` so that `fit` and `predict` accept either
    tensors (handled by the parent implementation) or a `DataLoader`
    (handled by the subclass hooks `_fit` and `_predict`).
    """

    def fit(self, x: TensorLike | DataLoader, y: TensorLike | None = None):
        """Train a spatio-temporal emulator.

        Parameters
        ----------
        x: TensorLike | DataLoader
            Input features as `TensorLike` or `DataLoader`.
        y: TensorLike | None
            Target values (not needed if x is a DataLoader).

        Raises
        ------
        RuntimeError
            If the inputs are neither a pair of `TensorLike` nor a
            `DataLoader` with `y` set to ``None``.
        """
        if isinstance(x, TensorLike) and isinstance(y, TensorLike):
            return super().fit(x, y)
        if isinstance(x, DataLoader) and y is None:
            result = self._fit(x, y)
            # The tensor path marks the emulator as fitted inside the parent
            # `fit`; do the same here so both paths leave the same state.
            self.is_fitted_ = True
            return result
        msg = "Invalid input types. Expected pair of TensorLike or DataLoader only."
        raise RuntimeError(msg)

    @abstractmethod
    def _fit(self, x: TensorLike | DataLoader, y: TensorLike | None = None): ...

    def predict(
        self, x: TensorLike | DataLoader, with_grad: bool = False
    ) -> OutputLike:
        """Predict the output for the given input.

        Tensor input is delegated to the parent `predict`; any other input is
        passed straight to `_predict`.

        Parameters
        ----------
        x: TensorLike | DataLoader
            Input `TensorLike` or `DataLoader` to make predictions for.
        with_grad: bool
            Whether to enable gradient calculation. Defaults to False.

        Returns
        -------
        OutputLike
            The emulator predicted spatiotemporal output.
        """
        if isinstance(x, TensorLike):
            return super().predict(x, with_grad)
        return self._predict(x, with_grad)

    @abstractmethod
    def _predict(self, x: TensorLike | DataLoader, with_grad: bool) -> OutputLike: ...

    # TODO: add method for rollout predictions
    # def predict_rollout(self, x: DataLoader, timesteps: int = 1) -> OutputLike:
    #     """
    #     Predict the output for the given input, rolling out for a number of timesteps.

    #     Parameters
    #     ----------
    #     x: DataLoader
    #         Input `DataLoader` to make predictions for.
    #     timesteps: int
    #         Number of timesteps to rollout for. Defaults to 1.

    #     Returns
    #     -------
    #     OutputLike
    #         The emulator predicted spatiotemporal output.
    #     """

    #     # Start at t=0 x_0
    #     # model predicts x_1 given x_0
    #     # then model predicts x_2 given model's predicted x_1
    #     # then model predicts x_3 given model's predicted x_2
0 commit comments

Comments
 (0)