Skip to content

Commit dee64c8

Browse files
authored
Merge pull request #861 from alan-turing-institute/add_tutorial
Add tutorial on adding emulators
2 parents 9d8fad4 + 20b0b44 commit dee64c8

6 files changed

Lines changed: 261 additions & 56 deletions

File tree

autoemulate/emulators/gaussian_process/exact.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,7 @@
1010
from torch import nn, optim
1111
from torch.optim.lr_scheduler import LRScheduler
1212

13-
from autoemulate.callbacks.early_stopping import (
14-
EarlyStopping,
15-
EarlyStoppingException,
16-
)
13+
from autoemulate.callbacks.early_stopping import EarlyStopping, EarlyStoppingException
1714
from autoemulate.core.device import TorchDeviceMixin
1815
from autoemulate.core.types import (
1916
DeviceLike,
@@ -23,18 +20,15 @@
2320
)
2421
from autoemulate.data.utils import set_random_seed
2522
from autoemulate.emulators.base import GaussianProcessEmulator
26-
from autoemulate.emulators.gaussian_process import (
27-
CovarModuleFn,
28-
MeanModuleFn,
29-
)
23+
from autoemulate.emulators.gaussian_process import CovarModuleFn, MeanModuleFn
3024
from autoemulate.transforms.standardize import StandardizeTransform
3125
from autoemulate.transforms.utils import make_positive_definite
3226

3327
from .kernel import (
3428
matern_3_2_kernel,
3529
matern_5_2_kernel,
3630
matern_5_2_plus_rq,
37-
rbf,
31+
rbf_kernel,
3832
rbf_plus_constant,
3933
rbf_plus_linear,
4034
rbf_times_linear,
@@ -306,7 +300,7 @@ def get_tune_params():
306300
poly_mean,
307301
],
308302
"covar_module_fn": [
309-
rbf,
303+
rbf_kernel,
310304
matern_5_2_kernel,
311305
matern_3_2_kernel,
312306
rq_kernel,
@@ -556,7 +550,7 @@ def get_tune_params():
556550
GaussianProcessRBF = create_gp_subclass(
557551
"GaussianProcessRBF",
558552
GaussianProcess,
559-
covar_module_fn=rbf,
553+
covar_module_fn=rbf_kernel,
560554
mean_module_fn=constant_mean,
561555
)
562556
GaussianProcessMatern32 = create_gp_subclass(

autoemulate/emulators/gaussian_process/kernel.py

Lines changed: 29 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
)
1010

1111

12-
def rbf(n_features: int | None, n_outputs: torch.Size | None) -> RBFKernel:
12+
def rbf_kernel(n_features: int | None, n_outputs: torch.Size | None) -> RBFKernel:
1313
"""
1414
Radial Basis Function (RBF) kernel.
1515
@@ -132,6 +132,30 @@ def rq_kernel(n_features: int | None, n_outputs: torch.Size | None) -> RQKernel:
132132
)
133133

134134

135+
def linear_kernel(n_features: int | None, n_outputs: torch.Size | None) -> LinearKernel:
136+
"""
137+
Linear kernel.
138+
139+
Parameters
140+
----------
141+
n_features: int | None
142+
Number of input features. If None, the kernel is not initialized with a
143+
lengthscale.
144+
n_outputs: torch.Size | None
145+
Batch shape of the kernel. If None, the kernel is not initialized with a
146+
batch shape.
147+
148+
Returns
149+
-------
150+
LinearKernel
151+
The initialized Linear kernel.
152+
"""
153+
return LinearKernel(
154+
ard_num_dims=n_features,
155+
batch_shape=n_outputs,
156+
)
157+
158+
135159
def rbf_plus_constant(n_features: int | None, n_outputs: torch.Size | None) -> Kernel:
136160
"""
137161
Radial Basis Function (RBF) kernel plus a constant kernel.
@@ -150,13 +174,7 @@ def rbf_plus_constant(n_features: int | None, n_outputs: torch.Size | None) -> K
150174
Kernel
151175
The initialized RBF kernel plus a constant kernel.
152176
"""
153-
rbf_kernel = RBFKernel(
154-
ard_num_dims=n_features,
155-
batch_shape=n_outputs,
156-
)
157-
if n_features is not None:
158-
rbf_kernel.initialize(lengthscale=torch.ones(n_features) * 1.5)
159-
return rbf_kernel + ConstantKernel()
177+
return rbf_kernel(n_features, n_outputs) + ConstantKernel()
160178

161179

162180
# combinations
@@ -178,16 +196,7 @@ def rbf_plus_linear(n_features: int | None, n_outputs: torch.Size | None) -> Ker
178196
Kernel
179197
The initialized RBF kernel plus a linear kernel.
180198
"""
181-
rbf_kernel = RBFKernel(
182-
ard_num_dims=n_features,
183-
batch_shape=n_outputs,
184-
)
185-
if n_features is not None:
186-
rbf_kernel.initialize(lengthscale=torch.ones(n_features) * 1.5)
187-
return rbf_kernel + LinearKernel(
188-
ard_num_dims=n_features,
189-
batch_shape=n_outputs,
190-
)
199+
return rbf_kernel(n_features, n_outputs) + linear_kernel(n_features, n_outputs)
191200

192201

193202
def matern_5_2_plus_rq(n_features: int | None, n_outputs: torch.Size | None) -> Kernel:
@@ -208,20 +217,7 @@ def matern_5_2_plus_rq(n_features: int | None, n_outputs: torch.Size | None) ->
208217
Kernel
209218
The initialized Matern 5/2 kernel plus a Rational Quadratic kernel.
210219
"""
211-
matern_kernel = MaternKernel(
212-
nu=2.5,
213-
ard_num_dims=n_features,
214-
batch_shape=n_outputs,
215-
)
216-
rq_kernel = RQKernel(
217-
ard_num_dims=n_features,
218-
batch_shape=n_outputs,
219-
)
220-
# Initialize lengthscales for both kernels if n_features is provided
221-
if n_features is not None:
222-
matern_kernel.initialize(lengthscale=torch.ones(n_features) * 1.5)
223-
rq_kernel.initialize(lengthscale=torch.ones(n_features) * 1.5)
224-
return matern_kernel + rq_kernel
220+
return matern_5_2_kernel(n_features, n_outputs) + rq_kernel(n_features, n_outputs)
225221

226222

227223
def rbf_times_linear(n_features: int | None, n_outputs: torch.Size | None) -> Kernel:
@@ -241,13 +237,4 @@ def rbf_times_linear(n_features: int | None, n_outputs: torch.Size | None) -> Ke
241237
Kernel
242238
The initialized RBF kernel multiplied by a linear kernel.
243239
"""
244-
rbf_kernel = RBFKernel(
245-
ard_num_dims=n_features,
246-
batch_shape=n_outputs,
247-
)
248-
if n_features is not None:
249-
rbf_kernel.initialize(lengthscale=torch.ones(n_features) * 1.5)
250-
return rbf_kernel * LinearKernel(
251-
ard_num_dims=n_features,
252-
batch_shape=n_outputs,
253-
)
240+
return rbf_kernel(n_features, n_outputs) * linear_kernel(n_features, n_outputs)

docs/_toc.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ chapters:
2323
- file: tutorials/simulator/01_custom_simulations
2424
- file: tutorials/simulator/02_active_learning
2525
- file: tutorials/simulator/03_history_matching
26-
26+
- file: tutorials/advanced/index
27+
sections:
28+
- file: tutorials/advanced/01_add_emulators
2729

2830
- file: community/index
2931
sections:
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "0",
6+
"metadata": {},
7+
"source": [
8+
"# Adding emulators\n",
9+
"\n",
10+
"In addition to providing a library of core emulators, AutoEmulate is designed to be easily extensible. This tutorial walks you through the steps of adding new emulators to the library. We cover two scenarios: adding new Gaussian Process kernels and adding entirely new models."
11+
]
12+
},
13+
{
14+
"cell_type": "markdown",
15+
"id": "1",
16+
"metadata": {},
17+
"source": [
18+
"## 1. Adding Gaussian Process kernels\n",
19+
"\n",
20+
"Gaussian Processes (GPs) are primarily defined by their kernel functions, which determine the covariance structure of the data. AutoEmulate includes several built-in GP kernels:\n",
21+
"- Radial Basis Function (RBF)\n",
22+
"- Matern 3/2\n",
23+
"- Matern 5/2\n",
24+
"- Rational Quadratic (RQ)\n",
25+
"- Linear\n",
26+
"\n",
27+
"You can easily create new kernels by composing any two or more of these existing kernels. For example, you might want to create a kernel that combines the RBF and Linear kernels to capture both smooth variations and linear trends in your data.\n",
28+
"\n",
29+
"In AutoEmulate, each kernel is defined by an initialisation function that takes as inputs the number of data input features and the number of output features. Below we define a custom kernel function following this pattern."
30+
]
31+
},
32+
{
33+
"cell_type": "code",
34+
"execution_count": null,
35+
"id": "2",
36+
"metadata": {},
37+
"outputs": [],
38+
"source": [
39+
"from autoemulate.emulators.gaussian_process.kernel import rbf_kernel, linear_kernel\n",
40+
"\n",
41+
"def rbs_plus_linear_kernel(n_features, n_outputs):\n",
42+
" \"\"\"\n",
43+
" Example of a custom kernel function that combines RBF and linear kernels.\n",
44+
" \"\"\"\n",
45+
" return rbf_kernel(n_features, n_outputs) + linear_kernel(n_features, n_outputs)"
46+
]
47+
},
48+
{
49+
"cell_type": "markdown",
50+
"id": "3",
51+
"metadata": {},
52+
"source": [
53+
"Once this function has been defined, you can create a new GP emulator class using the `create_gp_subclass` function."
54+
]
55+
},
56+
{
57+
"cell_type": "code",
58+
"execution_count": null,
59+
"id": "4",
60+
"metadata": {},
61+
"outputs": [],
62+
"source": [
63+
"from autoemulate.emulators.gaussian_process.exact import GaussianProcess, create_gp_subclass\n",
64+
"\n",
65+
"GaussianProcessRBFandLinear = create_gp_subclass(\n",
66+
" \"GaussianProcessRBFandLinear\", \n",
67+
" GaussianProcess, \n",
68+
" # the custom kernel function goes here\n",
69+
" covar_module_fn=rbs_plus_linear_kernel,\n",
70+
")"
71+
]
72+
},
73+
{
74+
"cell_type": "markdown",
75+
"id": "5",
76+
"metadata": {},
77+
"source": [
78+
"Now we can tell AutoEmulate to use the new GP class by passing it to the `models` argument when initialising an `AutoEmulate` object."
79+
]
80+
},
81+
{
82+
"cell_type": "code",
83+
"execution_count": null,
84+
"id": "6",
85+
"metadata": {},
86+
"outputs": [],
87+
"source": [
88+
"from autoemulate import AutoEmulate\n",
89+
"import torch\n",
90+
"\n",
91+
"# create some example data\n",
92+
"x = torch.linspace(0, 1, 100).unsqueeze(-1)\n",
93+
"y = torch.sin(2 * 3.14 * x) + 0.1 * torch.randn_like(x)\n",
94+
"\n",
95+
"ae = AutoEmulate(x, y, models=[GaussianProcessRBFandLinear])"
96+
]
97+
},
98+
{
99+
"cell_type": "code",
100+
"execution_count": null,
101+
"id": "7",
102+
"metadata": {},
103+
"outputs": [],
104+
"source": [
105+
"ae.summarise()"
106+
]
107+
},
108+
{
109+
"cell_type": "markdown",
110+
"id": "8",
111+
"metadata": {},
112+
"source": [
113+
"## 2. Adding new models\n",
114+
"\n",
115+
 "It is also possible to add entirely new models to AutoEmulate. AutoEmulate has a base `Emulator` class that handles most of the general functionality required for training and prediction. To implement a new emulator, one must simply subclass `Emulator` and implement the abstract methods (`_fit`, `_predict` and `is_multioutput`) and `get_tune_params` to enable model tuning, as well as any model-specific functionality and initialisations.\n",
116+
"\n",
117+
"Since AutoEmulate supports a variety of models, there are additional `Emulator` subclasses that handle specific functionality for each model type:\n",
118+
"- `PytorchBackend` for PyTorch models\n",
119+
"- `SklearnBackend` for scikit-learn models\n",
120+
"- `GaussianProcess` for exact Gaussian Process implementations\n",
121+
"- `Ensemble` for ensemble models\n",
122+
"\n",
123+
"Subclassing one of these directly has slightly different requirements. For example, when subclassing `PytorchBackend` or `GaussianProcess`, one must implement the `forward` method to define the model's forward pass.\n",
124+
"\n",
125+
"There are also some static methods that should be implemented to provide metadata about the model, such as `is_multioutput` and `get_tune_params`.\n",
126+
"\n",
127+
"Below demonstrates adding a simple feedforward neural network (FNN) using PyTorch. The new class `SimpleFNN` subclasses `PytorchBackend`, which already handles fitting and prediction."
128+
]
129+
},
130+
{
131+
"cell_type": "code",
132+
"execution_count": null,
133+
"id": "9",
134+
"metadata": {},
135+
"outputs": [],
136+
"source": [
137+
"from autoemulate.core.device import TorchDeviceMixin\n",
138+
"from autoemulate.emulators.base import PyTorchBackend\n",
139+
"import torch.nn as nn\n",
140+
"\n",
141+
"class SimpleFNN(PyTorchBackend):\n",
142+
" def __init__(\n",
143+
" self, \n",
144+
" x, \n",
145+
" y,\n",
146+
" hidden_dim=64,\n",
147+
" device = None,\n",
148+
" ):\n",
149+
" TorchDeviceMixin.__init__(self, device=device)\n",
150+
" nn.Module.__init__(self)\n",
151+
" \n",
152+
" input_dim = x.shape[1]\n",
153+
" output_dim = y.shape[1] if len(y.shape) > 1 else 1\n",
154+
" layers = []\n",
155+
" layers.append(nn.Linear(input_dim, hidden_dim, device=self.device))\n",
156+
" layers.append(nn.ReLU())\n",
157+
" layers.append(nn.Linear(hidden_dim, output_dim, device=self.device))\n",
158+
" self.model = nn.Sequential(*layers)\n",
159+
" self.optimizer = self.optimizer_cls(self.model.parameters(), lr=self.lr) # type: ignore[call-arg] since all optimizers include lr\n",
160+
" self.scheduler = None\n",
161+
" self.to(self.device)\n",
162+
" \n",
163+
" def forward(self, x):\n",
164+
" return self.model(x)\n",
165+
" \n",
166+
" @staticmethod\n",
167+
" def is_multioutput():\n",
168+
" return True\n",
169+
" \n",
170+
" @staticmethod\n",
171+
" def get_tune_params():\n",
172+
" return {\n",
173+
" \"hidden_dim\": [32, 64, 128]\n",
174+
" }"
175+
]
176+
},
177+
{
178+
"cell_type": "code",
179+
"execution_count": null,
180+
"id": "10",
181+
"metadata": {},
182+
"outputs": [],
183+
"source": [
184+
"ae = AutoEmulate(x, y, models=[SimpleFNN])"
185+
]
186+
},
187+
{
188+
"cell_type": "code",
189+
"execution_count": null,
190+
"id": "11",
191+
"metadata": {},
192+
"outputs": [],
193+
"source": [
194+
"ae.summarise()"
195+
]
196+
}
197+
],
198+
"metadata": {
199+
"kernelspec": {
200+
"display_name": ".venv",
201+
"language": "python",
202+
"name": "python3"
203+
},
204+
"language_info": {
205+
"codemirror_mode": {
206+
"name": "ipython",
207+
"version": 3
208+
},
209+
"file_extension": ".py",
210+
"mimetype": "text/x-python",
211+
"name": "python",
212+
"nbconvert_exporter": "python",
213+
"pygments_lexer": "ipython3",
214+
"version": "3.12.11"
215+
}
216+
},
217+
"nbformat": 4,
218+
"nbformat_minor": 5
219+
}

docs/tutorials/advanced/index.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Advanced usage
2+
3+
This section covers some of the more advanced features of AutoEmulate, including adding custom emulators.

0 commit comments

Comments
 (0)