Skip to content

Commit 8afd352

Browse files
authored
Merge pull request #949 from cmsamaaa/feat/torch-simulator-device-support
Add a device-aware TorchSimulator base class
2 parents d3b52ef + bcebe20 commit 8afd352

File tree

2 files changed

+178
-2
lines changed

2 files changed

+178
-2
lines changed

autoemulate/simulations/base.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from scipy.stats import qmc
66
from tqdm import tqdm
77

8+
from autoemulate.core.device import TorchDeviceMixin
89
from autoemulate.core.logging_config import get_configured_logger
9-
from autoemulate.core.types import TensorLike
10+
from autoemulate.core.types import DeviceLike, TensorLike
1011
from autoemulate.data.utils import ValidationMixin, set_random_seed
1112

1213
logger = logging.getLogger("autoemulate")
@@ -362,3 +363,99 @@ def get_outputs_as_dict(self) -> dict[str, TensorLike]:
362363
output_dict[output_name] = self.results_tensor[:, i]
363364

364365
return output_dict
366+
367+
368+
class TorchSimulator(Simulator, TorchDeviceMixin):
369+
"""
370+
Simulator that runs computations on a specified torch device.
371+
372+
This subclass extends :class:`Simulator` with the :class:`TorchDeviceMixin`
373+
so that simulators implemented in PyTorch (e.g., ``torchcor``) can run on
374+
CPU or accelerator hardware. Inputs are moved to ``self.device`` before the
375+
forward pass and the resulting tensors are kept on the same device.
376+
"""
377+
378+
def __init__(
379+
self,
380+
parameters_range: dict[str, tuple[float, float]],
381+
output_names: list[str],
382+
log_level: str = "progress_bar",
383+
device: DeviceLike | None = None,
384+
):
385+
Simulator.__init__(self, parameters_range, output_names, log_level)
386+
TorchDeviceMixin.__init__(self, device=device)
387+
388+
def sample_inputs(
389+
self, n_samples: int, random_seed: int | None = None, method: str = "lhs"
390+
) -> TensorLike:
391+
"""
392+
Sample inputs and move them to the simulator's device.
393+
394+
Parameters
395+
----------
396+
n_samples: int
397+
Number of samples to generate.
398+
random_seed: int | None
399+
Optional random seed to make sampling reproducible.
400+
method: str
401+
Sampling method, one of ``"lhs"`` or ``"sobol"``.
402+
403+
Returns
404+
-------
405+
TensorLike
406+
Sampled inputs located on ``self.device``.
407+
"""
408+
samples = super().sample_inputs(
409+
n_samples, random_seed=random_seed, method=method
410+
)
411+
(samples_device,) = self._move_tensors_to_device(samples)
412+
return samples_device
413+
414+
def forward(self, x: TensorLike, allow_failures: bool = True) -> TensorLike | None:
415+
"""
416+
Run a single simulation on the configured device.
417+
418+
Parameters
419+
----------
420+
x: TensorLike
421+
Input tensor with shape ``(n_samples, in_dim)``.
422+
allow_failures: bool
423+
When True, failures return ``None`` instead of raising.
424+
425+
Returns
426+
-------
427+
TensorLike | None
428+
Simulation result on ``self.device`` or ``None`` on failure.
429+
"""
430+
(x_device,) = self._move_tensors_to_device(x)
431+
y = super().forward(x_device, allow_failures=allow_failures)
432+
if isinstance(y, torch.Tensor):
433+
return y.to(self.device)
434+
return y
435+
436+
def forward_batch(
437+
self, x: TensorLike, allow_failures: bool = True
438+
) -> tuple[TensorLike, TensorLike]:
439+
"""
440+
Run a batch of simulations with device management.
441+
442+
Parameters
443+
----------
444+
x: TensorLike
445+
Batch of inputs with shape ``(batch_size, in_dim)``.
446+
allow_failures: bool
447+
Whether to skip failures (True) or raise immediately (False).
448+
449+
Returns
450+
-------
451+
tuple[TensorLike, TensorLike]
452+
Tuple of ``(results, valid_inputs)`` residing on ``self.device``.
453+
"""
454+
(x_device,) = self._move_tensors_to_device(x)
455+
results, valid_inputs = super().forward_batch(
456+
x_device, allow_failures=allow_failures
457+
)
458+
results = results.to(self.device)
459+
self.results_tensor = results
460+
valid_inputs = valid_inputs.to(self.device)
461+
return results, valid_inputs

tests/simulations/test_base_simulator.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import torch
33
from autoemulate.core.types import TensorLike
4-
from autoemulate.simulations.base import Simulator
4+
from autoemulate.simulations.base import Simulator, TorchSimulator
55
from torch import Tensor
66

77

@@ -33,6 +33,25 @@ def _forward(self, x: TensorLike) -> TensorLike | None:
3333
return torch.cat(outputs, dim=1)
3434

3535

36+
class TorchMockSimulator(TorchSimulator):
37+
"""Torch-based simulator for testing TorchSimulator behaviour."""
38+
39+
def __init__(
40+
self,
41+
parameters_range: dict[str, tuple[float, float]],
42+
output_names: list[str],
43+
device: str | torch.device | None = "cpu",
44+
):
45+
super().__init__(parameters_range, output_names, device=device)
46+
47+
def _forward(self, x: TensorLike) -> TensorLike | None:
48+
outputs = []
49+
for i, _ in enumerate(self._output_names):
50+
output = torch.sum(x, dim=1) * (i + 1)
51+
outputs.append(output.view(-1, 1))
52+
return torch.cat(outputs, dim=1)
53+
54+
3655
@pytest.fixture
3756
def parameters_range():
3857
"""Create test parameter ranges"""
@@ -324,3 +343,63 @@ def test_forward_batch_allow_failures():
324343
assert len(results) == 2 # All simulations successful
325344
assert len(valid_inputs) == 2
326345
assert torch.allclose(valid_inputs, success_batch)
346+
347+
348+
def test_torch_simulator_initializes_device(parameters_range):
349+
sim = TorchMockSimulator(parameters_range, ["var1", "var2"], device="cpu")
350+
assert sim.device == torch.device("cpu")
351+
352+
353+
def test_torch_simulator_forward_moves_inputs(parameters_range, monkeypatch):
354+
sim = TorchMockSimulator(parameters_range, ["var1", "var2"], device="cpu")
355+
original_move = TorchSimulator._move_tensors_to_device
356+
calls = {"count": 0}
357+
358+
def recording_move(self, *args):
359+
calls["count"] += 1
360+
return original_move(self, *args)
361+
362+
monkeypatch.setattr(TorchSimulator, "_move_tensors_to_device", recording_move)
363+
sim.forward(torch.tensor([[0.1, 0.2, 0.3]], dtype=torch.float32))
364+
assert calls["count"] == 1
365+
366+
367+
def test_torch_simulator_forward_batch_moves_inputs(parameters_range, monkeypatch):
368+
sim = TorchMockSimulator(parameters_range, ["var1"], device="cpu")
369+
original_move = TorchSimulator._move_tensors_to_device
370+
calls = {"count": 0}
371+
372+
def recording_move(self, *args):
373+
calls["count"] += 1
374+
return original_move(self, *args)
375+
376+
monkeypatch.setattr(TorchSimulator, "_move_tensors_to_device", recording_move)
377+
batch = torch.tensor(
378+
[
379+
[0.2, 0.5, 0.5],
380+
[0.6, 0.5, 0.5],
381+
],
382+
dtype=torch.float32,
383+
)
384+
sim.forward_batch(batch)
385+
# Forward batch moves the entire tensor once, then each per-sample forward
386+
# moves the corresponding slice.
387+
assert calls["count"] == 1 + len(batch)
388+
389+
390+
def test_torch_simulator_sample_inputs_on_device(parameters_range, monkeypatch):
391+
sim = TorchMockSimulator(parameters_range, ["var1"], device="cpu")
392+
original_move = TorchSimulator._move_tensors_to_device
393+
calls = {"count": 0}
394+
395+
def recording_move(self, *args):
396+
calls["count"] += 1
397+
return original_move(self, *args)
398+
399+
monkeypatch.setattr(
400+
TorchSimulator,
401+
"_move_tensors_to_device",
402+
recording_move,
403+
)
404+
sim.sample_inputs(5)
405+
assert calls["count"] == 1

0 commit comments

Comments
 (0)