Skip to content

Commit db79f5e

Browse files
committed
Fix device mismatch when moving TransformedEmulator to a new device
Use in calibration workflows
1 parent 8557027 commit db79f5e

5 files changed

Lines changed: 87 additions & 6 deletions

File tree

autoemulate/calibration/bayes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585
calibration_params = list(parameter_range.keys())
8686
self.calibration_params = calibration_params
8787
self.emulator = emulator
88-
self.emulator.device = self.device
88+
self.emulator.to(self.device)
8989
self.output_names = list(observations.keys())
9090
self.logger, self.progress_bar = get_configured_logger(log_level)
9191
self.logger.info(

autoemulate/calibration/history_matching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def __init__(
364364
raise ValueError(msg)
365365

366366
self.transformed_emulator_params = transformed_emulator_params or {}
367-
self.emulator.device = self.device
367+
self.emulator.to(self.device)
368368

369369
# New data is simulated in `run()` and appended here
370370
# It can be used to refit the emulator

autoemulate/calibration/interval_excursion_set.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def __init__(
131131
self.calibration_params = list(parameters_range.keys())
132132
self.d = len(self.parameters_range)
133133
self.emulator = emulator
134-
self.emulator.device = self.device
134+
self.emulator.to(self.device)
135135
# TODO: we might want to check that the len equals the number of tasks returned
136136
self.output_names = output_names
137137
self.logger, self.progress_bar = get_configured_logger(log_level)

autoemulate/emulators/base.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
from torch.distributions import TransformedDistribution
99
from torch.optim.lr_scheduler import ExponentialLR, LRScheduler
1010

11-
from autoemulate.core.device import TorchDeviceMixin
11+
from autoemulate.core.device import TorchDeviceMixin, get_torch_device
1212
from autoemulate.core.types import (
13+
DeviceLike,
1314
DistributionLike,
1415
GaussianLike,
1516
NumpyLike,
@@ -37,6 +38,30 @@ class Emulator(ABC, ValidationMixin, ConversionMixin, TorchDeviceMixin):
3738
y_transform: StandardizeTransform | None = None
3839
supports_uq: bool = False
3940

41+
def to(self, device: DeviceLike) -> "Emulator": # type: ignore[override]
42+
"""
43+
Move the emulator to the given device.
44+
45+
Subclasses may override this to move additional state (e.g. transforms,
46+
cached tensors). The base implementation updates ``self.device`` and, for
47+
emulators that are also ``nn.Module`` instances, delegates to
48+
``nn.Module.to()`` to move parameters and buffers.
49+
50+
Parameters
51+
----------
52+
device: DeviceLike
53+
The target device.
54+
55+
Returns
56+
-------
57+
Emulator
58+
``self``, for method chaining.
59+
"""
60+
self.device = get_torch_device(device)
61+
if isinstance(self, nn.Module):
62+
nn.Module.to(self, self.device)
63+
return self
64+
4065
@abstractmethod
4166
def _fit(self, x: TensorLike, y: TensorLike): ...
4267

@@ -528,7 +553,9 @@ def predict(self, x: TensorLike, with_grad: bool = False) -> GaussianLike:
528553
return pred
529554

530555

531-
class PyTorchBackend(nn.Module, Emulator):
556+
class PyTorchBackend( # type: ignore[reportIncompatibleMethodOverride]
557+
nn.Module, Emulator
558+
):
532559
"""
533560
`PyTorchBackend` provides a backend for PyTorch models.
534561

autoemulate/emulators/transformed/base.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import torch
22
from linear_operator.operators import DiagLinearOperator
3+
from torch import nn
34
from torch.distributions import ComposeTransform, Transform, TransformedDistribution
45
from torch.func import jacrev
56

6-
from autoemulate.core.device import TorchDeviceMixin
7+
from autoemulate.core.device import TorchDeviceMixin, get_torch_device
78
from autoemulate.core.types import (
89
DeviceLike,
910
DistributionLike,
@@ -203,6 +204,59 @@ def _fit_transforms(self, x: TensorLike, y: TensorLike):
203204
all(self._y_transforms_affine) if self._y_transforms_affine else False
204205
)
205206

207+
def to(self, device: DeviceLike) -> "TransformedEmulator":
208+
"""
209+
Move the emulator and all its state to the given device.
210+
211+
Moves the underlying model, transforms, and cached tensors to ``device``.
212+
213+
Parameters
214+
----------
215+
device: DeviceLike
216+
The target device (e.g. ``"cpu"``, ``"mps"``, ``"cuda"``).
217+
218+
Returns
219+
-------
220+
TransformedEmulator
221+
``self``, for method chaining.
222+
"""
223+
device = get_torch_device(device)
224+
self.device = device
225+
226+
# Move the underlying model (Emulator.to handles nn.Module delegation)
227+
self.model.to(device)
228+
# Clear cached prediction strategies that hold stale device refs
229+
if hasattr(self.model, "_clear_cache"):
230+
self.model._clear_cache() # type: ignore[attr-defined]
231+
232+
# Move the inner model's own transforms (e.g. StandardizeTransform)
233+
for attr in ("x_transform", "y_transform"):
234+
transform = getattr(self.model, attr, None)
235+
if transform is not None:
236+
self._move_transform_to_device(transform, device)
237+
238+
# Move transform state tensors
239+
for transform in self.x_transforms + self.y_transforms:
240+
self._move_transform_to_device(transform, device)
241+
242+
# Move cached Jacobian
243+
if self._fixed_jacobian_y_inv is not None:
244+
self._fixed_jacobian_y_inv = self._fixed_jacobian_y_inv.to(device)
245+
246+
return self
247+
248+
@staticmethod
249+
def _move_transform_to_device(transform: Transform, device: torch.device) -> None:
250+
"""Move a transform's tensor attributes to the given device."""
251+
if isinstance(transform, nn.Module):
252+
transform.to(device)
253+
if hasattr(transform, "device"):
254+
object.__setattr__(transform, "device", device)
255+
for attr_name in list(vars(transform)):
256+
val = getattr(transform, attr_name)
257+
if isinstance(val, torch.Tensor):
258+
setattr(transform, attr_name, val.to(device))
259+
206260
def refit(self, x: TensorLike, y: TensorLike, retrain_transforms: bool = False):
207261
"""
208262
Refit the emulator with new data and optionally retrain transforms.

0 commit comments

Comments (0)