|
5 | 5 | from scipy.stats import qmc |
6 | 6 | from tqdm import tqdm |
7 | 7 |
|
| 8 | +from autoemulate.core.device import TorchDeviceMixin |
8 | 9 | from autoemulate.core.logging_config import get_configured_logger |
9 | | -from autoemulate.core.types import TensorLike |
| 10 | +from autoemulate.core.types import DeviceLike, TensorLike |
10 | 11 | from autoemulate.data.utils import ValidationMixin, set_random_seed |
11 | 12 |
|
12 | 13 | logger = logging.getLogger("autoemulate") |
@@ -362,3 +363,99 @@ def get_outputs_as_dict(self) -> dict[str, TensorLike]: |
362 | 363 | output_dict[output_name] = self.results_tensor[:, i] |
363 | 364 |
|
364 | 365 | return output_dict |
| 366 | + |
| 367 | + |
class TorchSimulator(Simulator, TorchDeviceMixin):
    """
    Device-aware :class:`Simulator` for PyTorch-backed models.

    Combines :class:`Simulator` with :class:`TorchDeviceMixin` so that
    simulators written in PyTorch (e.g., ``torchcor``) can execute on CPU or
    accelerator hardware. Every input tensor is relocated to ``self.device``
    before the forward pass, and every returned tensor stays on that device.
    """

    def __init__(
        self,
        parameters_range: dict[str, tuple[float, float]],
        output_names: list[str],
        log_level: str = "progress_bar",
        device: DeviceLike | None = None,
    ):
        # Initialize each base explicitly: Simulator sets up parameter/output
        # bookkeeping, TorchDeviceMixin resolves and stores the target device.
        Simulator.__init__(self, parameters_range, output_names, log_level)
        TorchDeviceMixin.__init__(self, device=device)

    def sample_inputs(
        self, n_samples: int, random_seed: int | None = None, method: str = "lhs"
    ) -> TensorLike:
        """
        Draw input samples and relocate them to the simulator's device.

        Parameters
        ----------
        n_samples: int
            Number of samples to generate.
        random_seed: int | None
            Optional random seed to make sampling reproducible.
        method: str
            Sampling method, one of ``"lhs"`` or ``"sobol"``.

        Returns
        -------
        TensorLike
            Sampled inputs located on ``self.device``.
        """
        drawn = super().sample_inputs(n_samples, random_seed=random_seed, method=method)
        # _move_tensors_to_device returns a tuple; take its sole element.
        return self._move_tensors_to_device(drawn)[0]

    def forward(self, x: TensorLike, allow_failures: bool = True) -> TensorLike | None:
        """
        Execute one simulation on the configured device.

        Parameters
        ----------
        x: TensorLike
            Input tensor with shape ``(n_samples, in_dim)``.
        allow_failures: bool
            When True, failures return ``None`` instead of raising.

        Returns
        -------
        TensorLike | None
            Simulation result on ``self.device`` or ``None`` on failure.
        """
        (moved,) = self._move_tensors_to_device(x)
        out = super().forward(moved, allow_failures=allow_failures)
        # Non-tensor results (e.g. None on failure) pass through untouched.
        return out.to(self.device) if isinstance(out, torch.Tensor) else out

    def forward_batch(
        self, x: TensorLike, allow_failures: bool = True
    ) -> tuple[TensorLike, TensorLike]:
        """
        Execute a batch of simulations with device management.

        Parameters
        ----------
        x: TensorLike
            Batch of inputs with shape ``(batch_size, in_dim)``.
        allow_failures: bool
            Whether to skip failures (True) or raise immediately (False).

        Returns
        -------
        tuple[TensorLike, TensorLike]
            Tuple of ``(results, valid_inputs)`` residing on ``self.device``.
        """
        (moved,) = self._move_tensors_to_device(x)
        batch_results, kept_inputs = super().forward_batch(
            moved, allow_failures=allow_failures
        )
        batch_results = batch_results.to(self.device)
        kept_inputs = kept_inputs.to(self.device)
        # Keep the cached results tensor on-device so get_outputs_as_dict and
        # other consumers see device-resident data.
        self.results_tensor = batch_results
        return batch_results, kept_inputs
0 commit comments