diff --git a/autoemulate/core/compare.py b/autoemulate/core/compare.py index b763d8eb2..3c6c5c85b 100644 --- a/autoemulate/core/compare.py +++ b/autoemulate/core/compare.py @@ -71,6 +71,7 @@ def __init__( log_level: str = "progress_bar", tuning_metric: str | Metric = "r2", evaluation_metrics: list[str | Metric] | None = None, + n_samples: int = 1000, ): """ Initialize the AutoEmulate class. @@ -130,6 +131,9 @@ def __init__( Each entry can be a string shortcut or a MetricConfig object. IMPORTANT: The first metric in the list is used to determine the best model. + n_samples: int + Number of samples to generate to predict mean when emulator does not have a + mean directly available. Defaults to 1000. """ Results.__init__(self) self.random_seed = random_seed @@ -187,6 +191,7 @@ def __init__( # Set up logger and ModelSerialiser for saving models self.logger, self.progress_bar = get_configured_logger(log_level) self.model_serialiser = ModelSerialiser(self.logger) + self.n_samples = n_samples # Run compare self.compare() @@ -417,6 +422,9 @@ def compare(self): n_splits=self.n_splits, shuffle=self.shuffle, transformed_emulator_params=self.transformed_emulator_params, + metric_params=MetricParams( + n_samples=self.n_samples + ), ) mean_scores = [ np.mean(score).item() for score in scores @@ -484,7 +492,9 @@ def compare(self): n_bootstraps=self.n_bootstraps, device=self.device, metrics=self.evaluation_metrics, - metric_params=MetricParams(y_train=train_val_y), + metric_params=MetricParams( + n_samples=self.n_samples, y_train=train_val_y + ), ) test_metrics = bootstrap( transformed_emulator, @@ -493,7 +503,9 @@ def compare(self): n_bootstraps=self.n_bootstraps, device=self.device, metrics=self.evaluation_metrics, - metric_params=MetricParams(y_train=train_val_y), + metric_params=MetricParams( + n_samples=self.n_samples, y_train=train_val_y + ), ) # Log all test metrics from test_metrics dictionary diff --git a/autoemulate/core/model_selection.py 
b/autoemulate/core/model_selection.py index 782c67a80..8cccaaec9 100644 --- a/autoemulate/core/model_selection.py +++ b/autoemulate/core/model_selection.py @@ -1,5 +1,6 @@ import inspect import logging +from dataclasses import replace import torch from sklearn.model_selection import BaseCrossValidator @@ -61,6 +62,7 @@ def cross_validate( device: DeviceLike = "cpu", random_seed: int | None = None, metrics: list[Metric] | None = None, + metric_params: MetricParams | None = None, ): """ Cross validate model performance using the given `cv` strategy. @@ -85,6 +87,8 @@ def cross_validate( Optional random seed for reproducibility. metrics: list[TorchMetrics] | None List of metrics to compute. If None, uses r2 and rmse. + metric_params: MetricParams | None + Additional parameters to pass to the metrics. Defaults to None. Returns ------- @@ -94,6 +98,7 @@ def cross_validate( transformed_emulator_params = transformed_emulator_params or {} x_transforms = x_transforms or [] y_transforms = y_transforms or [] + metric_params = metric_params or MetricParams() # Setup metrics if metrics is None: @@ -143,7 +148,13 @@ def cross_validate( # compute and save results y_pred = transformed_emulator.predict(x_val) for metric in metrics: - score = evaluate(y_pred, y_val, metric) + score = evaluate( + # Update metric_params with y_train in case required by metric + y_pred, + y_val, + metric, + metric_params=replace(metric_params, y_train=y), + ) cv_results[metric.name].append(score) return cv_results diff --git a/autoemulate/core/tuner.py b/autoemulate/core/tuner.py index b3443e8b8..6c8f3c7fd 100644 --- a/autoemulate/core/tuner.py +++ b/autoemulate/core/tuner.py @@ -6,7 +6,7 @@ from torch.distributions import Transform from autoemulate.core.device import TorchDeviceMixin -from autoemulate.core.metrics import Metric, get_metric +from autoemulate.core.metrics import Metric, MetricParams, get_metric from autoemulate.core.model_selection import cross_validate from autoemulate.core.types 
import ( DeviceLike, @@ -74,6 +74,7 @@ def run( n_splits: int = 5, seed: int | None = None, shuffle: bool = True, + metric_params: MetricParams | None = None, ) -> tuple[list[list[float]], list[ModelParams]]: """ Run randomised hyperparameter search for a given model. @@ -97,6 +98,8 @@ def run( shuffle: bool Whether to shuffle data before splitting into cross validation folds. Defaults to True. + metric_params: MetricParams | None + Additional parameters to pass to the metrics. Defaults to None. Returns ------- @@ -130,6 +133,7 @@ def run( device=self.device, random_seed=None, metrics=[self.tuning_metric], + metric_params=metric_params, ) # Reset retries following a successful cross_validation call diff --git a/autoemulate/data/utils.py b/autoemulate/data/utils.py index 25914e7c1..da9ee5d94 100644 --- a/autoemulate/data/utils.py +++ b/autoemulate/data/utils.py @@ -192,6 +192,42 @@ def _denormalize( ) -> TensorLike: return (x * x_std) + x_mean + def output_to_tensor( + self, + output: OutputLike, + n_samples: int = 1000, + with_grad: bool = False, + ) -> torch.Tensor: + """Convert an output to a tensor (returns the mean if output is a distribution). + + Parameters + ---------- + output: OutputLike + The output to convert to a tensor. + n_samples: int + Number of samples to draw from the distribution. Defaults to 1000. + with_grad: bool + Whether to enable gradient calculation. Defaults to False. + + Returns + ------- + TensorLike + Tensor of shape `(n_batch, n_targets)` as input or the mean of the output if + output is a distribution. 
+ """ + if isinstance(output, TensorLike): + return output + try: + return output.mean + except Exception: + # Use sampling to get a mean if mean property not available + samples = ( + output.rsample(torch.Size([n_samples])) + if with_grad + else output.sample(torch.Size([n_samples])) + ) + return samples.mean(dim=0) + def set_random_seed(seed: int = 42, deterministic: bool = True): """ diff --git a/autoemulate/emulators/__init__.py b/autoemulate/emulators/__init__.py index 27738b322..21a8a6fb6 100644 --- a/autoemulate/emulators/__init__.py +++ b/autoemulate/emulators/__init__.py @@ -1,4 +1,5 @@ from .base import Emulator +from .conformal import ConformalMLP from .ensemble import EnsembleMLP, EnsembleMLPDropout from .gaussian_process.exact import ( GaussianProcessCorrelatedMatern32, @@ -26,6 +27,7 @@ __all__ = [ "MLP", + "ConformalMLP", "Emulator", "EnsembleMLP", "EnsembleMLPDropout", diff --git a/autoemulate/emulators/base.py b/autoemulate/emulators/base.py index 462335b7d..fc38cf6b2 100644 --- a/autoemulate/emulators/base.py +++ b/autoemulate/emulators/base.py @@ -39,9 +39,19 @@ class Emulator(ABC, ValidationMixin, ConversionMixin, TorchDeviceMixin): supports_uq: bool = False @abstractmethod - def _fit(self, x: TensorLike, y: TensorLike): ... + def _fit( + self, + x: TensorLike, + y: TensorLike, + validation_data: tuple[TensorLike, TensorLike] | None = None, + ): ... 
- def fit(self, x: TensorLike, y: TensorLike): + def fit( + self, + x: TensorLike, + y: TensorLike, + validation_data: tuple[TensorLike, TensorLike] | None = None, + ): """Fit the emulator to the provided data.""" # Ensure x and y are tensors and 2D x, y = self._convert_to_tensors(x, y) @@ -58,7 +68,7 @@ def fit(self, x: TensorLike, y: TensorLike): y = self.y_transform(y) if self.y_transform is not None else y # Fit emulator - self._fit(x, y) + self._fit(x, y, validation_data) self.is_fitted_ = True @abstractmethod @@ -152,18 +162,7 @@ def predict_mean( """ x = self._ensure_with_grad(x, with_grad) y_pred = self._predict(x, with_grad) - if isinstance(y_pred, TensorLike): - return y_pred - try: - return y_pred.mean - except Exception: - # Use sampling to get a mean if mean property not available - samples = ( - y_pred.rsample(torch.Size([n_samples])) - if with_grad - else y_pred.sample(torch.Size([n_samples])) - ) - return samples.mean(dim=0) + return self.output_to_tensor(y_pred, n_samples) def predict_mean_and_variance( self, x: TensorLike, with_grad: bool = False, n_samples: int = 100 @@ -559,7 +558,12 @@ def loss_func(self, y_pred, y_true): """Loss function to be used for training the model.""" return nn.MSELoss()(y_pred, y_true) - def _fit(self, x: TensorLike, y: TensorLike): + def _fit( + self, + x: TensorLike, + y: TensorLike, + validation_data: tuple[TensorLike, TensorLike] | None = None, # noqa: ARG002 + ): """ Train a PyTorchBackend model. 
@@ -683,7 +687,12 @@ class SklearnBackend(DeterministicEmulator): def _model_specific_check(self, x: NumpyLike, y: NumpyLike): _, _ = x, y - def _fit(self, x: TensorLike, y: TensorLike): + def _fit( + self, + x: TensorLike, + y: TensorLike, + validation_data: tuple[TensorLike, TensorLike] | None = None, # noqa: ARG002 + ): if self.normalize_y: y, y_mean, y_std = self._normalize(y) self.y_mean = y_mean diff --git a/autoemulate/emulators/conformal.py b/autoemulate/emulators/conformal.py new file mode 100644 index 000000000..1a2752ead --- /dev/null +++ b/autoemulate/emulators/conformal.py @@ -0,0 +1,676 @@ +import math +import sys +from collections.abc import Callable +from typing import Literal + +import torch +from torch import nn +from torch.optim.lr_scheduler import LRScheduler + +from autoemulate.core.device import TorchDeviceMixin +from autoemulate.core.types import DeviceLike, DistributionLike, TensorLike, TuneParams +from autoemulate.emulators.base import Emulator, PyTorchBackend +from autoemulate.emulators.nn.mlp import MLP, _generate_mlp_docstring + + +class QuantileLoss(nn.Module): + """Quantile loss for quantile regression. + + This loss function asymmetrically penalizes over- and under-predictions, enabling + the model to learn specific quantiles of the conditional distribution. + """ + + def __init__(self, quantile: float): + """Initialize quantile loss. + + Parameters + ---------- + quantile: float + Target quantile level in (0, 1). For example, 0.1 for 10th percentile, 0.5 + for median, 0.9 for 90th percentile. + """ + super().__init__() + if not 0 < quantile < 1: + msg = f"Quantile must be in (0, 1), got {quantile}" + raise ValueError(msg) + self.quantile = quantile + + def forward(self, y_pred: TensorLike, y_true: TensorLike) -> TensorLike: + """Compute quantile loss. + + Parameters + ---------- + y_pred: TensorLike + Predicted values. + y_true: TensorLike + True target values. + + Returns + ------- + TensorLike + Scalar loss value. 
+ """ + errors = y_true - y_pred + # Mean across batch and targets + return torch.max(self.quantile * errors, (self.quantile - 1) * errors).mean() + + +class QuantileMLP(MLP): + """MLP with quantile loss for quantile regression.""" + + def __init__(self, quantile: float, **kwargs): + """Initialize quantile MLP. + + Parameters + ---------- + quantile: float + Target quantile level in (0, 1). + **kwargs + Keyword arguments passed to MLP parent class. + """ + super().__init__(**kwargs) + self.quantile = quantile + self.quantile_loss = QuantileLoss(quantile) + + def loss_func(self, y_pred, y_true): + """Quantile loss function.""" + return self.quantile_loss(y_pred, y_true) + + +class Conformal(Emulator): + """Conformal Uncertainty Quantification (UQ) wrapper for emulators. + + This class wraps a base emulator to provide conformal prediction intervals with + guaranteed frequentist coverage. + + Both standard split conformal and Conformalized Quantile Regression (CQR) methods + are supported. + + Conformalized Quantile Regression (CQR) is defaultly implemented with two neural net + quantile regressors predicting lower and upper quantiles, followed by a calibration + step to ensure valid coverage. Note the _fit_quantile_regressors method can be + overridden to implement custom quantile regressors. + + Additional methods for input-dependent intervals (such as scaling) can be + implemented by adding further supported "method" strings and providing corresponding + logic in the _fit and _predict methods. + + References + ---------- + - Romano, Y., Patterson, E., & Candes, E. (2019). Conformalized Quantile Regression. + In Advances in Neural Information Processing Systems (Vol. 32). 
+ https://papers.nips.cc/paper/8613-conformalized-quantile-regression.pdf + + """ + + supports_uq = True + + def __init__( + self, + emulator: Emulator, + alpha: float = 0.95, + device: DeviceLike | None = None, + calibration_ratio: float = 0.2, + n_samples: int = 1000, + method: Literal["constant", "quantile"] = "constant", + to_distribution: Callable[ + [TensorLike | None, tuple[TensorLike, TensorLike]], DistributionLike + ] = lambda _mean, bounds: torch.distributions.Uniform( + torch.min(bounds[0], bounds[1]), torch.max(bounds[0], bounds[1]) + ), + quantile_emulator_kwargs: dict | None = None, + ): + """Initialize a conformal emulator. + + Parameters + ---------- + emulator: Emulator + Base emulator to wrap for conformal UQ. + alpha: float + Desired predictive coverage level (e.g., 0.95 for 95% coverage). Must be in + (0, 1). + device: DeviceLike | None + Device to run the model on (e.g., "cpu", "cuda"). Defaults to None. + calibration_ratio: float + Fraction of the training data to reserve for calibration if explicit + validation data is not provided. Must lie in (0, 1). Defaults to 0.2. + n_samples: int + Number of samples used for sampling-based predictions or internal + procedures. Defaults to 1000. + method: Literal["constant", "quantile"] + Conformalization method to use: + - "constant": Standard split conformal with constant-width intervals + - "quantile": Conformalized Quantile Regression (CQR) with input-dependent + intervals. Defaults to "constant". + to_distribution: Callable[[TensorLike | None, tuple[TensorLike, TensorLike]], DistributionLike] + A callable that takes an optional mean and a tuple of lower and upper bounds + as input and returns a distribution over that interval. + Defaults to lambda _mean, bounds: torch.distributions.Uniform(bounds[0], bounds[1]). + quantile_emulator_kwargs: dict | None + Additional keyword arguments for the quantile emulators when + method="quantile". Defaults to None. 
+ """ # noqa: E501 + self.emulator = emulator + self.supports_grad = emulator.supports_grad + if not 0 < alpha < 1: + msg = "Conformal coverage level alpha must be in (0, 1)." + raise ValueError(msg) + if not 0 < calibration_ratio < 1: + msg = "Calibration ratio must lie strictly between 0 and 1." + raise ValueError(msg) + if method not in {"constant", "quantile"}: + msg = f"Method must be 'constant' or 'quantile', got '{method}'." + raise ValueError(msg) + self.alpha = alpha # desired predictive coverage (e.g., 0.95) + self.calibration_ratio = calibration_ratio + self.n_samples = n_samples + self.method = method + self.to_distribution = to_distribution + self.quantile_emulator_kwargs = quantile_emulator_kwargs or {} + TorchDeviceMixin.__init__(self, device=device) + self.supports_grad = emulator.supports_grad + + @staticmethod + def is_multioutput() -> bool: + """Ensemble supports multi-output.""" + return True + + @staticmethod + def get_tune_params() -> TuneParams: + """Return a dictionary of hyperparameters to tune.""" + return {} + + def _fit( + self, + x: TensorLike, + y: TensorLike, + validation_data: tuple[TensorLike, TensorLike] | None = None, + ): + x_train, y_train = x, y + + # If not validation data passed, take random permutation of training data and + # hold out a calibration set according to calibration_ratio + if validation_data is None: + n_samples = x.shape[0] + if n_samples < 2: + msg = "At least two samples are required to create a calibration split." + raise ValueError(msg) + + n_cal = max(1, math.ceil(n_samples * self.calibration_ratio)) + if n_cal >= n_samples: + n_cal = n_samples - 1 + perm = torch.randperm(n_samples, device=x.device) + cal_idx = perm[:n_cal] + train_idx = perm[n_cal:] + if train_idx.numel() == 0: + msg = "Calibration split left no samples for training." 
+ raise ValueError(msg) + x_cal = x[cal_idx] + y_true_cal = y[cal_idx] + x_train = x[train_idx] + y_train = y[train_idx] + else: + x_cal, y_true_cal = validation_data + + # Fit the base emulator + self.emulator.fit(x_train, y_train, validation_data=None) + + n_cal = x_cal.shape[0] + # Check calibration data is non-empty + if n_cal == 0: + msg = "Calibration set must contain at least one sample." + raise ValueError(msg) + + with torch.no_grad(): + # Predict and calculate residuals + y_pred_cal = self.output_to_tensor(self.emulator.predict(x_cal)) + + # Standard split conformal for constant-width intervals + if self.method == "constant": + # Compute absolute residuals + residuals = torch.abs(y_true_cal - y_pred_cal) + + # Apply finite-sample correction to quantile level + quantile_level = min(1.0, math.ceil((n_cal + 1) * self.alpha) / n_cal) + + # Calibrate over the batch dim with a separate quantile per output + self.q = torch.quantile(residuals, quantile_level, dim=0) + + # Conformalized Quantile Regression for input-dependent intervals + elif self.method == "quantile": + # Train quantile regressors + self._fit_quantile_regressors(x_train, y_train, x_cal, y_true_cal) + + self.is_fitted_ = True + + def _fit_quantile_regressors( + self, + x_train: TensorLike, + y_train: TensorLike, + x_cal: TensorLike, + y_true_cal: TensorLike, + ): + """Fit quantile regressors for CQR method. + + Trains two quantile regressors to predict lower and upper quantiles, + then calibrates the width using the calibration set. 
+ """ + # Calculate quantile levels + lower_q = (1 - self.alpha) / 2 + upper_q = 1 - lower_q + + # Lower quantile emulator + self.lower_quantile_emulator = QuantileMLP( + lower_q, + x=x_train, + y=y_train, + device=self.device, + **self.quantile_emulator_kwargs, + ) + + # Upper quantile emulator + self.upper_quantile_emulator = QuantileMLP( + upper_q, + x=x_train, + y=y_train, + device=self.device, + **self.quantile_emulator_kwargs, + ) + + # Fit the quantile emulators + self.lower_quantile_emulator.fit(x_train, y_train, validation_data=None) + self.upper_quantile_emulator.fit(x_train, y_train, validation_data=None) + + # Predict quantiles on calibration set + with torch.no_grad(): + lower_pred_cal = self.output_to_tensor( + self.lower_quantile_emulator.predict(x_cal) + ) + upper_pred_cal = self.output_to_tensor( + self.upper_quantile_emulator.predict(x_cal) + ) + + # Calculate conformalization scores (non-conformity scores) + # For CQR, the score is max(lower - y, y - upper) + scores = torch.maximum( + lower_pred_cal - y_true_cal, y_true_cal - upper_pred_cal + ) + + # Apply finite-sample correction + n_cal = x_cal.shape[0] + quantile_level = min(1.0, math.ceil((n_cal + 1) * self.alpha) / n_cal) + + # Compute the correction term per output dimension + self.q_cqr = torch.quantile(scores, quantile_level, dim=0) + + def _predict(self, x: TensorLike, with_grad: bool) -> DistributionLike: + # Standard split conformal: constant width intervals + if self.method == "constant": + pred = self.emulator.predict(x, with_grad) + mean = self.output_to_tensor(pred) + q = self.q.to(mean.device) + return torch.distributions.Independent( + self.to_distribution(None, (mean - q, mean + q)), + reinterpreted_batch_ndims=mean.ndim - 1, + ) + + # Conformalized Quantile Regression: input-dependent intervals + if self.method == "quantile": + lower_pred = self.output_to_tensor( + self.lower_quantile_emulator.predict(x, with_grad) + ) + upper_pred = self.output_to_tensor( + 
self.upper_quantile_emulator.predict(x, with_grad) + ) + q_cqr = self.q_cqr.to(lower_pred.device) + + # Apply calibration correction + lower_bound = lower_pred - q_cqr + upper_bound = upper_pred + q_cqr + + # Return uniform distribution over the calibrated interval + return torch.distributions.Independent( + self.to_distribution(None, (lower_bound, upper_bound)), + reinterpreted_batch_ndims=lower_bound.ndim - 1, + ) + + msg = f"Unknown method: {self.method}" + raise ValueError(msg) + + +class ConformalMLP(Conformal, PyTorchBackend): + """Conformal UQ with an MLP. + + This class is to provides UQ via conformal prediction intervals wrapped around a + Multi-Layer Perceptron (MLP) emulator. + + Both standard split conformal and Conformalized Quantile Regression (CQR) methods + are supported. + + """ + + def __init__( + self, + x: TensorLike, + y: TensorLike, + standardize_x: bool = True, + standardize_y: bool = True, + activation_cls: type[nn.Module] = nn.ReLU, + loss_fn_cls: type[nn.Module] = nn.MSELoss, + epochs: int = 100, + batch_size: int = 16, + layer_dims: list[int] | None = None, + weight_init: str = "default", + scale: float = 1.0, + bias_init: str = "default", + dropout_prob: float | None = None, + lr: float = 1e-2, + params_size: int = 1, + random_seed: int | None = None, + device: DeviceLike | None = None, + scheduler_cls: type[LRScheduler] | None = None, + scheduler_params: dict | None = None, + alpha: float = 0.95, + calibration_ratio: float = 0.2, + method: Literal["constant", "quantile"] = "constant", + quantile_emulator_kwargs: dict | None = None, + ): + nn.Module.__init__(self) + + # Construct docstring + conformal_kwargs = """ + alpha: float + Desired predictive coverage level forwarded to the conformal wrapper. + calibration_ratio: float + Fraction of training samples to hold out for calibration when an explicit + validation set is not provided. 
+ method: Literal["constant", "quantile"] + Conformalization method: + - "constant": Standard split conformal (constant-width intervals) + - "quantile": Conformalized Quantile Regression (input-dependent intervals) + Defaults to "constant". + quantile_emulator_kwargs: dict | None + Additional keyword arguments for the quantile emulators when + method="quantile". Defaults to None. + """ + conformal_mlp_params = _generate_mlp_docstring( + additional_parameters_docstring=conformal_kwargs, + default_dropout_prob=None, + ) + self.__doc__ = ( + """ Initialize a conformal MLP emulator.\n\n""" + conformal_mlp_params + ) + + emulator = MLP( + x, + y, + standardize_x=standardize_x, + standardize_y=standardize_y, + device=device, + activation_cls=activation_cls, + loss_fn_cls=loss_fn_cls, + epochs=epochs, + batch_size=batch_size, + layer_dims=layer_dims, + weight_init=weight_init, + scale=scale, + bias_init=bias_init, + dropout_prob=dropout_prob, + lr=lr, + params_size=params_size, + random_seed=random_seed, + scheduler_cls=scheduler_cls, + scheduler_params=scheduler_params, + ) + + quantile_defaults = { + "standardize_x": standardize_x, + "standardize_y": standardize_y, + "activation_cls": activation_cls, + "loss_fn_cls": loss_fn_cls, + "epochs": epochs, + "batch_size": batch_size, + "layer_dims": layer_dims, + "weight_init": weight_init, + "scale": scale, + "bias_init": bias_init, + "dropout_prob": dropout_prob, + "lr": lr, + "params_size": params_size, + "random_seed": random_seed, + "scheduler_cls": scheduler_cls, + "scheduler_params": scheduler_params, + } + merged_quantile_kwargs = { + **quantile_defaults, + **(quantile_emulator_kwargs or {}), + } + Conformal.__init__( + self, + emulator=emulator, + alpha=alpha, + device=device, + calibration_ratio=calibration_ratio, + method=method, + quantile_emulator_kwargs=merged_quantile_kwargs, + ) + + @staticmethod + def is_multioutput() -> bool: + """Ensemble of MLPs supports multi-output.""" + return True + + @staticmethod + 
def get_tune_params() -> TuneParams: + """Return a dictionary of hyperparameters to tune.""" + return MLP.get_tune_params() + + +def create_conformal_subclass( + name: str, + conformal_mlp_base_class: type[ConformalMLP], + method: Literal["constant", "quantile"], + auto_register: bool = True, + overwrite: bool = True, + **fixed_kwargs, +) -> type[ConformalMLP]: + """ + Create a subclass of ConformalMLP with given fixed_kwargs. + + This function creates a subclass of ConformalMLP where certain parameters + are fixed to specific values, reducing the parameter space for tuning. + + The created subclass is automatically registered with the main emulator Registry + (unless auto_register=False), making it discoverable by AutoEmulate. + + Parameters + ---------- + name: str + Name for the created subclass. + conformal_mlp_base_class: type[ConformalMLP] + Base class to inherit from (typically ConformalMLP). + method: Literal["constant", "quantile"] + Conformalization method to use in the subclass. + auto_register : bool + Whether to automatically register the created subclass with the main emulator + Registry. Defaults to True. + overwrite : bool + Whether to allow overwriting an existing class with the same name in the + main Registry. Useful for interactive development in notebooks. Defaults to + True. + **fixed_kwargs + Keyword arguments to fix in the subclass. These parameters will be + set to the provided values and excluded from hyperparameter tuning. + + Returns + ------- + type[ConformalMLP] + A new subclass of ConformalMLP with the specified parameters fixed. + The returned class can be pickled and used like any other GP emulator. + + Raises + ------ + ValueError + If `name` matches `model_name()` or `short_name()` of an already registered + emulator in the main Registry and `overwrite=False`. + + Notes + ----- + - Fixed parameters are automatically excluded from `get_tune_params()` to prevent + them from being included in hyperparameter optimization. 
+ - Pickling: The created subclass is registered in the caller's module namespace, + ensuring it can be pickled and unpickled correctly even when created in downstream + code that uses autoemulate as a dependency. + - If auto_register=True (default), the class is also added to the main Registry. + """ + standardize_x = fixed_kwargs.get("standardize_x", True) + standardize_y = fixed_kwargs.get("standardize_y", True) + activation_cls: type[nn.Module] = fixed_kwargs.get("activation_cls", nn.ReLU) + loss_fn_cls: type[nn.Module] = fixed_kwargs.get("loss_fn_cls", nn.MSELoss) + epochs: int = fixed_kwargs.get("epochs", 100) + batch_size: int = fixed_kwargs.get("batch_size", 16) + layer_dims: list[int] | None = fixed_kwargs.get("layer_dims") + weight_init: str = fixed_kwargs.get("weight_init", "default") + scale: float = fixed_kwargs.get("scale", 1.0) + bias_init: str = fixed_kwargs.get("bias_init", "default") + dropout_prob: float | None = fixed_kwargs.get("dropout_prob") + lr: float = fixed_kwargs.get("lr", 1e-2) + params_size: int = fixed_kwargs.get("params_size", 1) + random_seed: int | None = fixed_kwargs.get("random_seed") + device: DeviceLike | None = fixed_kwargs.get("device") + scheduler_cls: type[LRScheduler] | None = fixed_kwargs.get("scheduler_cls") + scheduler_params: dict | None = fixed_kwargs.get("scheduler_params") + alpha: float = fixed_kwargs.get("alpha", 0.95) + calibration_ratio: float = fixed_kwargs.get("calibration_ratio", 0.2) + quantile_emulator_kwargs: dict | None = fixed_kwargs.get("quantile_emulator_kwargs") + + class ConformalMLPSubclass(conformal_mlp_base_class): + def __init__( + self, + x: TensorLike, + y: TensorLike, + standardize_x: bool = standardize_x, + standardize_y: bool = standardize_y, + activation_cls: type[nn.Module] = activation_cls, + loss_fn_cls: type[nn.Module] = loss_fn_cls, + epochs: int = epochs, + batch_size: int = batch_size, + layer_dims: list[int] | None = layer_dims, + weight_init: str = weight_init, + scale: float = 
scale, + bias_init: str = bias_init, + dropout_prob: float | None = dropout_prob, + lr: float = lr, + params_size: int = params_size, + random_seed: int | None = random_seed, + device: DeviceLike | None = device, + scheduler_cls: type[LRScheduler] | None = scheduler_cls, + scheduler_params: dict | None = scheduler_params, + alpha: float = alpha, + calibration_ratio: float = calibration_ratio, + method: Literal["constant", "quantile"] = method, + quantile_emulator_kwargs: dict | None = quantile_emulator_kwargs, + ): + super().__init__( + x, + y, + standardize_x=standardize_x, + standardize_y=standardize_y, + activation_cls=activation_cls, + loss_fn_cls=loss_fn_cls, + epochs=epochs, + batch_size=batch_size, + layer_dims=layer_dims, + weight_init=weight_init, + scale=scale, + bias_init=bias_init, + dropout_prob=dropout_prob, + lr=lr, + params_size=params_size, + random_seed=random_seed, + device=device, + scheduler_cls=scheduler_cls, + scheduler_params=scheduler_params, + alpha=alpha, + calibration_ratio=calibration_ratio, + method=method, + quantile_emulator_kwargs=quantile_emulator_kwargs, + ) + + @staticmethod + def get_tune_params(): + """Get tunable parameters, excluding those that are fixed.""" + tune_params = conformal_mlp_base_class.get_tune_params() + # Remove fixed parameters from tuning + tune_params.pop("method", None) + for key in fixed_kwargs: + tune_params.pop(key, None) + return tune_params + + # Create a more descriptive docstring that includes fixed parameters + method_and_fixed_kwargs = { + **fixed_kwargs, + } + fixed_params_str = "\n ".join( + f"- {k} = {v.__name__ if callable(v) else v}" + for k, v in method_and_fixed_kwargs.items() + ) + + ConformalMLPSubclass.__doc__ = f""" + {conformal_mlp_base_class.__doc__} + + Notes + ----- + {name} is a subclass of {conformal_mlp_base_class.__name__} and has the following + parameters set during initialization: + {fixed_params_str} + + For any parameters set with this approach, they are also excluded from 
the search + space when tuning. For example, if the `method` is set to `constant`, + the "constant" method will always be used as the `method`. Note that in this case + the associated hyperparameters (such as lengthscale) will still be fitted during + model training and are not fixed. + """ + + # Determine the caller's module for proper pickling support. + # When called from autoemulate itself, use __name__. + # When called from user code, use the caller's module + caller_frame = sys._getframe(1) + caller_module_name = caller_frame.f_globals.get("__name__", __name__) + + # Set the class name and module + ConformalMLPSubclass.__name__ = name + ConformalMLPSubclass.__qualname__ = name + ConformalMLPSubclass.__module__ = caller_module_name + + # Register class in the caller's module globals for pickling + # This ensures the class can be pickled/unpickled correctly + caller_frame.f_globals[name] = ConformalMLPSubclass + # Also register in the caller's module if it's a real module (not __main__) + if caller_module_name in sys.modules and caller_module_name != "__main__": + setattr(sys.modules[caller_module_name], name, ConformalMLPSubclass) + + # Automatically register with the main emulator Registry if requested + if auto_register: + # Lazy import to avoid circular dependency with __init__.py + from autoemulate.emulators import register # noqa: PLC0415 + + register(ConformalMLPSubclass, overwrite=overwrite) + + return ConformalMLPSubclass + + +# Built-in GP subclasses - auto_register=False as already registered in Registry init: +# autoemulate/emulators/__init__.py +ConformalMLPConstant = create_conformal_subclass( + "ConformalMLPConstant", + ConformalMLP, + method="constant", + auto_register=False, +) +ConformalMLPQuantile = create_conformal_subclass( + "ConformalMLPQuantile", + ConformalMLP, + method="quantile", + auto_register=False, +) diff --git a/autoemulate/emulators/ensemble.py b/autoemulate/emulators/ensemble.py index 78da0b724..6bd71a3a4 100644 --- 
a/autoemulate/emulators/ensemble.py +++ b/autoemulate/emulators/ensemble.py @@ -59,9 +59,14 @@ def get_tune_params() -> TuneParams: """Return a dictionary of hyperparameters to tune.""" return {} - def _fit(self, x: TensorLike, y: TensorLike) -> None: + def _fit( + self, + x: TensorLike, + y: TensorLike, + validation_data: tuple[TensorLike, TensorLike] | None = None, + ) -> None: for e in self.emulators: - e.fit(x, y) + e.fit(x, y, validation_data=validation_data) self.is_fitted_ = True def _predict(self, x: Tensor, with_grad: bool) -> GaussianLike: @@ -159,16 +164,13 @@ def __init__( n_emulators: int Number of MLP emulators to create in the ensemble. Defaults to 4. """ - self.__doc__ = f""" - Initialize an ensemble of MLPs. - - { - _generate_mlp_docstring( + self.__doc__ = ( + " Initialize an ensemble of MLPs.\n" + + _generate_mlp_docstring( additional_parameters_docstring=additional_parameters_docstring, default_dropout_prob=None, ) - } - """ + ) emulators = [ MLP( @@ -261,9 +263,14 @@ def get_tune_params() -> TuneParams: "n_samples": [10, 20, 50, 100], } - def _fit(self, x: TensorLike, y: TensorLike) -> None: + def _fit( + self, + x: TensorLike, + y: TensorLike, + validation_data: tuple[TensorLike, TensorLike] | None = None, + ) -> None: # Delegate training to the wrapped model - self.model.fit(x, y) + self.model.fit(x, y, validation_data=validation_data) self.is_fitted_ = True def _predict(self, x: Tensor, with_grad: bool) -> GaussianLike: @@ -332,15 +339,12 @@ def __init__( scheduler_cls: type[LRScheduler] | None = None, scheduler_params: dict | None = None, ): - self.__doc__ = f""" - Initialize an ensemble of MLPs with dropout. 
- - { - _generate_mlp_docstring( + self.__doc__ = ( + " Initialize an ensemble of MLPs with dropout.\n" + + _generate_mlp_docstring( additional_parameters_docstring="", default_dropout_prob=0.2 ) - } - """ + ) DropoutEnsemble.__init__( self, MLP( diff --git a/autoemulate/emulators/gaussian_process/exact.py b/autoemulate/emulators/gaussian_process/exact.py index 2c0536549..23d16b63b 100644 --- a/autoemulate/emulators/gaussian_process/exact.py +++ b/autoemulate/emulators/gaussian_process/exact.py @@ -189,7 +189,12 @@ def forward(self, x: TensorLike): MultivariateNormal(mean, covar) ) - def _fit(self, x: TensorLike, y: TensorLike): + def _fit( + self, + x: TensorLike, + y: TensorLike, + validation_data: tuple[TensorLike, TensorLike] | None = None, # noqa: ARG002 + ): self.train() self.likelihood.train() diff --git a/autoemulate/emulators/lightgbm.py b/autoemulate/emulators/lightgbm.py index cc60cd375..7a09dd110 100644 --- a/autoemulate/emulators/lightgbm.py +++ b/autoemulate/emulators/lightgbm.py @@ -157,7 +157,12 @@ def is_multioutput() -> bool: """LightGBM does not support multi-output.""" return False - def _fit(self, x: TensorLike, y: TensorLike): # type: ignore since this is valid subclass of types + def _fit( + self, + x: TensorLike, + y: TensorLike, + validation_data: tuple[TensorLike, TensorLike] | None = None, # noqa: ARG002 + ) -> None: x_np, y_np = self._convert_to_numpy(x, y) self.n_features_in_ = x_np.shape[1] self.model_.fit(x_np, y_np) diff --git a/autoemulate/emulators/nn/mlp.py b/autoemulate/emulators/nn/mlp.py index f7d5fe47c..eb7be93db 100644 --- a/autoemulate/emulators/nn/mlp.py +++ b/autoemulate/emulators/nn/mlp.py @@ -125,17 +125,15 @@ def __init__( scheduler_cls: type[LRScheduler] | None = None, scheduler_params: dict | None = None, ): - self.__doc__ = f""" - Multi-Layer Perceptron (MLP) emulator. - - MLP provides a simple deterministic emulator with optional model stochasticity - provided by different weight initialization and dropout. 
- { - _generate_mlp_docstring( + self.__doc__ = ( + " Multi-Layer Perceptron (MLP) emulator.\n\n" + " MLP provides a simple deterministic emulator with optional model\n" + " stochasticity provided by different weight initialization and " + " dropout.\n" + + _generate_mlp_docstring( additional_parameters_docstring="", default_dropout_prob=None ) - } - """ + ) TorchDeviceMixin.__init__(self, device=device) nn.Module.__init__(self) diff --git a/autoemulate/emulators/radial_basis_functions.py b/autoemulate/emulators/radial_basis_functions.py index b0417a406..03b9cccf3 100644 --- a/autoemulate/emulators/radial_basis_functions.py +++ b/autoemulate/emulators/radial_basis_functions.py @@ -69,7 +69,12 @@ def __init__( self.degree = degree self.device = device - def _fit(self, x: TensorLike, y: TensorLike): + def _fit( + self, + x: TensorLike, + y: TensorLike, + validation_data: tuple[TensorLike, TensorLike] | None = None, # noqa: ARG002 + ): self.model = RBFInterpolator( x, y, diff --git a/autoemulate/emulators/registry.py b/autoemulate/emulators/registry.py index 3cfb55a2f..ff7def01b 100644 --- a/autoemulate/emulators/registry.py +++ b/autoemulate/emulators/registry.py @@ -3,6 +3,8 @@ import torch +from autoemulate.emulators.conformal import ConformalMLPConstant, ConformalMLPQuantile + from .base import Emulator, GaussianProcessEmulator from .ensemble import EnsembleMLP, EnsembleMLPDropout from .gaussian_process.exact import ( @@ -50,6 +52,8 @@ def __init__(self): GaussianProcessCorrelatedMatern32, GaussianProcessCorrelatedRBF, EnsembleMLPDropout, + ConformalMLPConstant, + ConformalMLPQuantile, ] self._pytorch_emulators: list[type[Emulator]] = [ diff --git a/autoemulate/emulators/transformed/base.py b/autoemulate/emulators/transformed/base.py index 2665c5f2b..99ae5d401 100644 --- a/autoemulate/emulators/transformed/base.py +++ b/autoemulate/emulators/transformed/base.py @@ -356,7 +356,12 @@ def _inv_transform_y_distribution(self, y_t: DistributionLike) -> DistributionLi 
""" return TransformedDistribution(y_t, [ComposeTransform(self.y_transforms).inv]) - def _fit(self, x: TensorLike, y: TensorLike): + def _fit( + self, + x: TensorLike, + y: TensorLike, + validation_data: tuple[TensorLike, TensorLike] | None = None, # noqa: ARG002 + ): # Transform x and y x_t = self._transform_x(x) y_t = self._transform_y_tensor(y) diff --git a/autoemulate/experimental/emulators/fno.py b/autoemulate/experimental/emulators/fno.py index 086290868..04d4a1de9 100644 --- a/autoemulate/experimental/emulators/fno.py +++ b/autoemulate/experimental/emulators/fno.py @@ -58,7 +58,12 @@ def __init__( def is_multioutput() -> bool: # noqa: D102 return True - def _fit(self, x: TensorLike | DataLoader, y: TensorLike | None = None): + def _fit( + self, + x: TensorLike | DataLoader, + y: TensorLike | None = None, + validation_data: tuple[TensorLike, TensorLike] | None = None, # noqa: ARG002 + ): assert isinstance(x, DataLoader), "x currently must be a DataLoader" assert y is None, "y currently must be None" diff --git a/autoemulate/experimental/emulators/spatiotemporal.py b/autoemulate/experimental/emulators/spatiotemporal.py index e5a181d62..c17d94bd4 100644 --- a/autoemulate/experimental/emulators/spatiotemporal.py +++ b/autoemulate/experimental/emulators/spatiotemporal.py @@ -11,7 +11,12 @@ class SpatioTemporalEmulator(PyTorchBackend): channels: tuple[int, ...] - def fit(self, x: TensorLike | DataLoader, y: TensorLike | None = None): + def fit( + self, + x: TensorLike | DataLoader, + y: TensorLike | None = None, + validation_data: tuple[TensorLike, TensorLike] | None = None, # noqa: ARG002 + ): """Train a spatio-temporal emulator. Parameters @@ -30,7 +35,12 @@ def fit(self, x: TensorLike | DataLoader, y: TensorLike | None = None): raise RuntimeError(msg) @abstractmethod - def _fit(self, x: TensorLike | DataLoader, y: TensorLike | None = None): ... 
+ def _fit( + self, + x: TensorLike | DataLoader, + y: TensorLike | None = None, + validation_data: tuple[TensorLike, TensorLike] | None = None, + ): ... def predict( self, diff --git a/tests/core/test_model_selection.py b/tests/core/test_model_selection.py index eb352fcef..24abce87a 100644 --- a/tests/core/test_model_selection.py +++ b/tests/core/test_model_selection.py @@ -19,7 +19,7 @@ def __init__(self, x=None, y=None, device=None, **kwargs): TorchDeviceMixin.__init__(self, device=device) _, _ = x, y - def _fit(self, x, y): + def _fit(self, x, y, validation_data=None): pass def _predict(self, x, with_grad=False): diff --git a/tests/emulators/test_base.py b/tests/emulators/test_base.py index ad9b5a8dd..e8574c89a 100644 --- a/tests/emulators/test_base.py +++ b/tests/emulators/test_base.py @@ -79,6 +79,34 @@ def test_fit(self): assert len(self.model.loss_history) == 10 assert all(isinstance(loss, float) for loss in self.model.loss_history) + def test_fit_with_validation_data(self): + x_train = torch.Tensor(np.array([[1.0], [2.0], [3.0]])) + y_train = torch.Tensor(np.array([[2.0], [4.0], [6.0]])) + x_val = torch.Tensor(np.array([[4.0]])) + y_val = torch.Tensor(np.array([[8.0]])) + + class DummyModelWithValidationData(self.DummyModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.validation_data_seen = None + + def _fit(self, x, y, validation_data=None): + self.validation_data_seen = validation_data + return super()._fit(x, y, validation_data=validation_data) + + model = DummyModelWithValidationData( + scheduler_cls=ExponentialLR, + scheduler_params={"gamma": 0.9}, + ) + + model.fit(x_train, y_train, validation_data=(x_val, y_val)) + + assert model.is_fitted_ + assert model.validation_data_seen is not None + val_x, val_y = model.validation_data_seen + assert torch.equal(val_x, x_val) + assert torch.equal(val_y, y_val) + def test_predict(self): """ Test the predict method of PyTorchBackend. 
diff --git a/tests/emulators/test_conformal.py b/tests/emulators/test_conformal.py new file mode 100644 index 000000000..65842fe46 --- /dev/null +++ b/tests/emulators/test_conformal.py @@ -0,0 +1,145 @@ +import torch +from autoemulate.core.types import DistributionLike, TensorLike +from autoemulate.emulators.conformal import ConformalMLP + + +def test_conformal_mlp(): + def f(x): + return torch.sin(x) + + # Training data + x_train = torch.rand(100, 3) * 10 + y_train = f(x_train) + + # Calibration data + x_cal, y_cal = torch.rand(100, 3) * 10, f(torch.rand(100, 3) * 10) + + emulator = ConformalMLP(x_train, y_train, layer_dims=[100, 100], lr=1e-2) + emulator.fit(x_train, y_train, validation_data=(x_cal, y_cal)) + + # Test + x_test = torch.linspace(0.0, 15.0, steps=1000).repeat(1, 3).reshape(-1, 3) + y_test_hat = emulator.predict(x_test) + assert isinstance(y_test_hat, DistributionLike) + assert isinstance(y_test_hat.mean, TensorLike) + assert isinstance(y_test_hat.variance, TensorLike) + assert y_test_hat.mean.shape == (1000, 3) + assert y_test_hat.variance.shape == (1000, 3) + assert not y_test_hat.mean.requires_grad + + y_test_hat_grad = emulator.predict(x_test, with_grad=True) + assert y_test_hat_grad.mean.requires_grad # type: ignore # noqa: PGH003 + + +def test_conformal_mlp_quantile_method(): + """Test Conformalized Quantile Regression (CQR) method.""" + + def f(x): + return torch.sin(x) + + # Training data + x_train = torch.rand(100, 3) * 10 + y_train = f(x_train) + + # Calibration data + x_cal, y_cal = torch.rand(100, 3) * 10, f(torch.rand(100, 3) * 10) + + emulator = ConformalMLP( + x_train, + y_train, + method="quantile", + layer_dims=[100, 100], + lr=1e-2, + epochs=50, + quantile_emulator_kwargs={"epochs": 50, "lr": 1e-2}, + ) + emulator.fit(x_train, y_train, validation_data=(x_cal, y_cal)) + + # Test + x_test = torch.linspace(0.0, 15.0, steps=100).repeat(1, 3).reshape(-1, 3) + y_test_hat = emulator.predict(x_test) + assert isinstance(y_test_hat, 
DistributionLike) + assert isinstance(y_test_hat.mean, TensorLike) + assert isinstance(y_test_hat.variance, TensorLike) + assert y_test_hat.mean.shape == (100, 3) + assert y_test_hat.variance.shape == (100, 3) + assert not y_test_hat.mean.requires_grad + + # Check that intervals vary across input space (not constant width) + interval_widths = y_test_hat.variance.sqrt() * 2 # approximate width + # Variance should differ across inputs for quantile method + assert interval_widths.std() > 0, "Intervals should vary across input space" + + +def test_conformal_methods_comparison(): + """Compare constant vs quantile conformal methods on heteroscedastic data.""" + + def heteroscedastic_function(x: TensorLike) -> tuple[TensorLike, TensorLike]: + """Function with heteroscedastic noise (variance depends on x).""" + mean = torch.sin(2 * x) + # Variance increases with x + noise_std = 0.1 + 0.3 * (x / 10.0).abs() + noise = torch.randn_like(mean) * noise_std + return mean, mean + noise + + # Generate training data with heteroscedastic noise + torch.manual_seed(42) + x_train = torch.rand(100, 1) * 10 - 5 + _, y_train = heteroscedastic_function(x_train) + + # Generate calibration data + x_cal = torch.rand(50, 1) * 10 - 5 + _, y_cal = heteroscedastic_function(x_cal) + + # Generate test data + x_test = torch.linspace(-5, 5, 50).reshape(-1, 1) + + # Test constant method + model_constant = ConformalMLP( + x_train, + y_train, + method="constant", + alpha=0.90, + layer_dims=[32, 16], + epochs=50, + lr=1e-2, + ) + model_constant.fit(x_train, y_train, validation_data=(x_cal, y_cal)) + pred_constant = model_constant.predict(x_test) + + # Test quantile method + model_quantile = ConformalMLP( + x_train, + y_train, + method="quantile", + alpha=0.90, + layer_dims=[32, 16], + epochs=50, + lr=1e-2, + quantile_emulator_kwargs={"epochs": 50, "lr": 1e-2}, + ) + model_quantile.fit(x_train, y_train, validation_data=(x_cal, y_cal)) + pred_quantile = model_quantile.predict(x_test) + + # Compare interval 
widths + with torch.no_grad(): + # Uniform distribution bounds provide interval limits directly + constant_base = pred_constant.base_dist # type: ignore[attr-defined] + quantile_base = pred_quantile.base_dist # type: ignore[attr-defined] + + width_constant = (constant_base.high - constant_base.low).squeeze() + width_quantile = (quantile_base.high - quantile_base.low).squeeze() + + # Constant conformal should have constant widths (std ≈ 0) + assert width_constant.std() < 1e-6, ( + "Constant method should have constant interval widths" + ) + + # Quantile conformal should have variable widths (std > 0) + assert width_quantile.std() > 0, ( + "Quantile method should have variable interval widths" + ) + + # Both methods should produce valid predictions + assert pred_constant.mean.shape == x_test.shape # type: ignore[attr-defined] + assert pred_quantile.mean.shape == x_test.shape # type: ignore[attr-defined] diff --git a/tests/emulators/test_grads.py b/tests/emulators/test_grads.py index 28ccbf610..20440b428 100644 --- a/tests/emulators/test_grads.py +++ b/tests/emulators/test_grads.py @@ -4,6 +4,7 @@ import torch from autoemulate.core.types import TensorLike from autoemulate.emulators import GAUSSIAN_PROCESS_EMULATORS, PYTORCH_EMULATORS +from autoemulate.emulators.conformal import Conformal, ConformalMLP from autoemulate.emulators.gaussian_process.exact import GaussianProcess from autoemulate.emulators.transformed.base import TransformedEmulator from autoemulate.transforms.pca import PCATransform @@ -18,8 +19,8 @@ def get_pytest_param_yof(model, x_t, y_t, o, f): - return ( - pytest.param( + if o and f and model.supports_uq: + return pytest.param( model, x_t, y_t, @@ -30,9 +31,21 @@ def get_pytest_param_yof(model, x_t, y_t, o, f): reason="Full covariance sampling not implemented", ), ) - if (o and f and model.supports_uq) - else (model, x_t, y_t, o, f) - ) + + if (not o) and issubclass(model, Conformal): + return pytest.param( + model, + x_t, + y_t, + o, + f, + 
marks=pytest.mark.xfail( + raises=ValueError, + reason="Conformal emulators require sampling for predictions", + ), + ) + + return (model, x_t, y_t, o, f) def get_parametrize_cases(): @@ -52,7 +65,8 @@ def get_parametrize_cases(): output_from_samples_and_full_covariance_cases_cases = [ get_pytest_param_yof(model, x_t, y_t, o, f) for model, x_t, y_t, o, f in itertools.product( - [GaussianProcess], + # ConformalMLP also included here since it is not tested above + [GaussianProcess, ConformalMLP], X_TRANSFORMS, Y_TRANSFORMS, [False, True], diff --git a/tests/emulators/test_mlp.py b/tests/emulators/test_mlp.py index 27bd8a89c..be8132316 100644 --- a/tests/emulators/test_mlp.py +++ b/tests/emulators/test_mlp.py @@ -71,9 +71,9 @@ def test_mlp_predict_deterministic_with_seed(sample_data_y2d, new_data_y2d): model3.fit(x, y) pred3 = model3.predict(x2) - assert isinstance(pred1, torch.Tensor) - assert isinstance(pred2, torch.Tensor) - assert isinstance(pred3, torch.Tensor) + assert isinstance(pred1, TensorLike) + assert isinstance(pred2, TensorLike) + assert isinstance(pred3, TensorLike) assert torch.allclose(pred1, pred2) msg = "Predictions should differ with different seeds." assert not torch.allclose(pred1, pred3), msg