Skip to content

Commit a7d773e

Browse files
authored
Merge pull request #820 from alan-turing-institute/advection-reaction
Add dataset for advection-diffusion example
2 parents cea1253 + bf9d34c commit a7d773e

2 files changed

Lines changed: 416 additions & 0 deletions

File tree

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "0",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"from autoemulate.simulations.advection_diffusion import AdvectionDiffusion\n",
11+
"\n",
12+
"\n",
13+
"rd = AdvectionDiffusion(n=50, T=80, dt=0.25, return_timeseries=True)\n",
14+
"data = rd.forward_samples_spatiotemporal(20)\n",
15+
"y = data[\"data\"]"
16+
]
17+
},
18+
{
19+
"cell_type": "code",
20+
"execution_count": null,
21+
"id": "1",
22+
"metadata": {},
23+
"outputs": [],
24+
"source": [
25+
"import matplotlib.pyplot as plt\n",
26+
"\n",
27+
"plt.imshow(y[0, 50, :, :, 0])\n"
28+
]
29+
},
30+
{
31+
"cell_type": "code",
32+
"execution_count": null,
33+
"id": "2",
34+
"metadata": {},
35+
"outputs": [],
36+
"source": [
37+
"from torch.utils.data import DataLoader\n",
38+
"from autoemulate.experimental.data.spatiotemporal_dataset import AutoEmulateDataset\n",
39+
"\n",
40+
"dataset = AutoEmulateDataset(data_path=None, data=data, n_steps_input=1, n_steps_output=1)\n",
41+
"batch = next(iter(DataLoader(dataset)))"
42+
]
43+
},
44+
{
45+
"cell_type": "code",
46+
"execution_count": null,
47+
"id": "3",
48+
"metadata": {},
49+
"outputs": [],
50+
"source": [
51+
"batch[\"input_fields\"].shape, batch[\"output_fields\"].shape, batch[\"constant_scalars\"].shape"
52+
]
53+
},
54+
{
55+
"cell_type": "code",
56+
"execution_count": null,
57+
"id": "4",
58+
"metadata": {},
59+
"outputs": [],
60+
"source": [
61+
"batch[\"input_fields\"].shape"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": null,
67+
"id": "5",
68+
"metadata": {},
69+
"outputs": [],
70+
"source": [
71+
"import matplotlib.pyplot as plt\n",
72+
"\n",
73+
"plt.imshow(batch[\"input_fields\"][0, 0, :, :, 0].cpu())\n",
74+
"plt.show()"
75+
]
76+
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": null,
80+
"id": "6",
81+
"metadata": {},
82+
"outputs": [],
83+
"source": [
84+
"from autoemulate.experimental.emulators.fno import FNOEmulator\n",
85+
"\n",
86+
"emulator = FNOEmulator(\n",
87+
" n_modes=(1, 16, 16),\n",
88+
" hidden_channels=16,\n",
89+
" in_channels=3,\n",
90+
" out_channels=1,\n",
91+
")\n"
92+
]
93+
},
94+
{
95+
"cell_type": "code",
96+
"execution_count": null,
97+
"id": "7",
98+
"metadata": {},
99+
"outputs": [],
100+
"source": [
101+
"# Fit the emulator\n",
102+
"emulator.fit(DataLoader(dataset), None)"
103+
]
104+
},
105+
{
106+
"cell_type": "code",
107+
"execution_count": null,
108+
"id": "8",
109+
"metadata": {},
110+
"outputs": [],
111+
"source": [
112+
"# Predictions\n",
113+
"y_pred = emulator.predict(DataLoader(dataset), with_grad=False)\n"
114+
]
115+
},
116+
{
117+
"cell_type": "code",
118+
"execution_count": null,
119+
"id": "9",
120+
"metadata": {},
121+
"outputs": [],
122+
"source": [
123+
"# Evaluate\n",
124+
"# TODO: add to emulator perhaps as .evaluate()?\n",
125+
"import torch\n",
126+
"from autoemulate.experimental.emulators.fno import prepare_batch\n",
127+
"\n",
128+
"y_true = torch.cat(\n",
129+
" [\n",
130+
" prepare_batch(\n",
131+
" batch, channels=(0,), with_constants=True, with_time=True\n",
132+
" )[1]\n",
133+
" for batch in DataLoader(dataset)\n",
134+
" ],\n",
135+
" dim=0\n",
136+
")\n"
137+
]
138+
},
139+
{
140+
"cell_type": "code",
141+
"execution_count": null,
142+
"id": "10",
143+
"metadata": {},
144+
"outputs": [],
145+
"source": [
146+
"from torchmetrics import R2Score\n",
147+
"\n",
148+
"R2Score()(y_pred.reshape(-1).detach(), y_true.reshape(-1).detach()).item()"
149+
]
150+
}
151+
],
152+
"metadata": {
153+
"kernelspec": {
154+
"display_name": ".venv",
155+
"language": "python",
156+
"name": "python3"
157+
},
158+
"language_info": {
159+
"codemirror_mode": {
160+
"name": "ipython",
161+
"version": 3
162+
},
163+
"file_extension": ".py",
164+
"mimetype": "text/x-python",
165+
"name": "python",
166+
"nbconvert_exporter": "python",
167+
"pygments_lexer": "ipython3",
168+
"version": "3.12.9"
169+
}
170+
},
171+
"nbformat": 4,
172+
"nbformat_minor": 5
173+
}

0 commit comments

Comments
 (0)