"""
Packages container for DeepAR model.
"""

from pytorch_forecasting.base._base_pkg import Base_pkg


class DeepAR_pkg_v2(Base_pkg):
    """Package container for the experimental v2 DeepAR model.

    Holds the metadata tags of the model together with lazy accessors for the
    model class, its DataModule class and the parameter sets used to build
    test instances.
    """

    _tags = {
        "info:name": "DeepAR",
        "info:compute": 3,
        # DeepAR only accepts distribution losses. The v2 test utilities
        # (``_setup_pkg_and_data``) look up this tag and skip injecting the
        # SMAPE point loss when "distr" is listed.
        "info:pred_type": ["distr"],
        "authors": ["jdb78"],
        "capability:exogenous": True,
        "capability:multivariate": True,
        "capability:pred_int": True,
        "capability:flexible_history_length": True,
        "capability:cold_start": False,
    }

    @classmethod
    def get_cls(cls):
        """Get model class.

        The import is done lazily inside the method, so importing this package
        container does not import the model module.
        """
        from pytorch_forecasting.models.deepar._deepar_v2 import DeepAR

        return DeepAR

    @classmethod
    def get_datamodule_cls(cls):
        """Get the underlying DataModule class."""
        from pytorch_forecasting.data.data_module import (
            EncoderDecoderTimeSeriesDataModule,
        )

        return EncoderDecoderTimeSeriesDataModule

    @classmethod
    def get_base_test_params(cls):
        """Return basic testing parameter settings.

        Returns
        -------
        params : list of dict
            An empty dict (all defaults) and a small GRU configuration. Neither
            entry specifies a ``loss``.
        """
        return [
            {},
            dict(
                cell_type="GRU",
                hidden_size=16,
                rnn_layers=2,
            ),
        ]

    @classmethod
    def get_test_train_params(cls):
        """Return testing parameter settings for the trainer.

        Returns
        -------
        params : list of dict
            Parameters to create testing instances of the class.
            Each dict holds the parameters to construct an "interesting" test
            instance, i.e., `MyClass(**params[i])` creates a valid test
            instance. Every dict carries a ``datamodule_cfg`` entry: the
            defaults below, overridden by any values the entry sets itself.
        """
        from pytorch_forecasting.metrics import NormalDistributionLoss

        params = [
            dict(
                loss=NormalDistributionLoss(),
            ),
            dict(
                loss=NormalDistributionLoss(),
                cell_type="GRU",
                hidden_size=16,
                rnn_layers=2,
            ),
            dict(
                loss=NormalDistributionLoss(),
                hidden_size=32,
                rnn_layers=3,
                dropout=0.2,
            ),
            dict(
                loss=NormalDistributionLoss(),
                hidden_size=20,
                datamodule_cfg=dict(
                    max_encoder_length=7,
                    max_prediction_length=5,
                ),
            ),
            dict(
                loss=NormalDistributionLoss(),
                hidden_size=16,
                n_validation_samples=50,
                n_plotting_samples=25,
            ),
            dict(
                loss=NormalDistributionLoss(),
                hidden_size=10,
                rnn_layers=1,
                dropout=0.0,
                datamodule_cfg=dict(
                    max_encoder_length=3,
                    max_prediction_length=2,
                ),
            ),
        ]

        default_dm_cfg = {
            "max_encoder_length": 4,
            "max_prediction_length": 3,
        }

        for param in params:
            # Later keys win, so per-entry settings override the defaults.
            param["datamodule_cfg"] = {
                **default_dm_cfg,
                **param.get("datamodule_cfg", {}),
            }

        return params
########################################################################################
# Disclaimer: This implementation is based on the new version of data pipeline and is
# experimental, please use with care.
########################################################################################

from collections.abc import Callable
from typing import Any, Literal, Optional, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer

from pytorch_forecasting.metrics import (
    DistributionLoss,
    MultiLoss,
    NormalDistributionLoss,
)
from pytorch_forecasting.models.base._base_model_v2 import BaseModel
from pytorch_forecasting.models.nn import HiddenState, get_rnn
from pytorch_forecasting.utils import apply_to_list


class _DummyEncoder:
    """Stand-in encoder used when no usable output transformer is set."""

    transformation = None
    center = False

    @property
    def transform(self):
        return self.transformation

    def __call__(self, x):
        return x.get("prediction", 0)


class DeepAR(BaseModel):
    """
    DeepAR: Probabilistic forecasting with autoregressive recurrent networks.

    Parameters
    ----------
    loss : nn.Module, optional
        Loss function to use. Must be a ``DistributionLoss`` or ``MultiLoss``.
        Defaults to ``NormalDistributionLoss()``.
    logging_metrics : list[nn.Module], optional
        Metrics to log during training.
    optimizer : Union[Optimizer, str], optional
        Optimizer to use. Defaults to "adam".
    optimizer_params : dict, optional
        Parameters for the optimizer.
    lr_scheduler : str, optional
        Learning rate scheduler.
    lr_scheduler_params : dict, optional
        Parameters for the learning rate scheduler.
    cell_type : Literal["LSTM", "GRU"], optional
        Recurrent cell type ["LSTM", "GRU"]. Defaults to "LSTM".
    hidden_size : int, optional
        Hidden recurrent size. Defaults to 10.
    rnn_layers : int, optional
        Number of RNN layers. Defaults to 2.
    dropout : float, optional
        Dropout in RNN layers. Defaults to 0.1. Only passed to the RNN when
        ``rnn_layers > 1``.
    metadata : dict, optional
        Metadata from the DataModule. The keys read here are
        ``max_encoder_length``, ``max_prediction_length``, ``encoder_cont``,
        ``encoder_cat``, ``decoder_cont``, ``decoder_cat`` and ``target_dim``.
    output_transformer : Any, optional
        Transformer handed to the loss as ``encoder`` when rescaling
        parameters. If ``None``, ``on_fit_start`` tries to pick up the
        DataModule's ``target_normalizer``.
    **kwargs : Any
        Additional keyword arguments; not read by this constructor.

    Raises
    ------
    ValueError
        If ``loss`` is neither a ``DistributionLoss`` nor a ``MultiLoss``.
    """

    @classmethod
    def _pkg(cls):
        """Link to the package container which holds metadata and test
        configurations."""
        from pytorch_forecasting.models.deepar.__deepar_pkg_v2 import (
            DeepAR_pkg_v2,
        )

        return DeepAR_pkg_v2

    def __init__(
        self,
        loss: nn.Module | None = None,
        logging_metrics: list[nn.Module] | None = None,
        optimizer: Optimizer | str | None = "adam",
        optimizer_params: dict | None = None,
        lr_scheduler: str | None = None,
        lr_scheduler_params: dict | None = None,
        cell_type: Literal["LSTM", "GRU"] = "LSTM",
        hidden_size: int = 10,
        rnn_layers: int = 2,
        dropout: float = 0.1,
        metadata: dict | None = None,
        output_transformer: Any = None,
        **kwargs: Any,
    ):
        if loss is None:
            loss = NormalDistributionLoss()

        if not isinstance(loss, (DistributionLoss, MultiLoss)):
            raise ValueError(
                f"DeepAR requires a 'DistributionLoss', but got {type(loss).__name__}. "
                "SMAPE is not supported as the primary training loss for DeepAR."
            )

        super().__init__(
            loss=loss,
            logging_metrics=logging_metrics,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            lr_scheduler=lr_scheduler,
            lr_scheduler_params=lr_scheduler_params,
        )
        self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"])

        self.cell_type = cell_type
        self.hidden_size = hidden_size
        self.rnn_layers = rnn_layers
        self.dropout = dropout
        self.metadata = metadata or {}
        self.output_transformer = output_transformer

        self.max_encoder_length = self.metadata.get("max_encoder_length", 0)
        self.max_prediction_length = self.metadata.get("max_prediction_length", 0)

        self.encoder_cont_dim = self.metadata.get("encoder_cont", 0)
        self.encoder_cat_dim = self.metadata.get("encoder_cat", 0)
        self.decoder_cont_dim = self.metadata.get("decoder_cont", 0)
        self.decoder_cat_dim = self.metadata.get("decoder_cat", 0)

        self.target_dim = self.metadata.get("target_dim", 1)

        rnn_class = get_rnn(cell_type)
        encoder_input_size = self.encoder_cont_dim + self.encoder_cat_dim
        decoder_input_size = self.decoder_cont_dim + self.decoder_cat_dim
        # Encoder and decoder features are projected to a common width so that
        # a single RNN can consume both.
        rnn_input_size = hidden_size

        self.encoder_projector = nn.Linear(encoder_input_size, rnn_input_size)
        self.decoder_projector = nn.Linear(decoder_input_size, rnn_input_size)

        self.rnn = rnn_class(
            input_size=rnn_input_size,
            hidden_size=hidden_size,
            num_layers=rnn_layers,
            dropout=dropout if rnn_layers > 1 else 0,
            batch_first=True,
        )

        if isinstance(self.loss, MultiLoss):
            n_outputs = sum(
                len(sub_loss.distribution_arguments) for sub_loss in self.loss
            )
        else:
            n_outputs = len(self.loss.distribution_arguments) * self.target_dim

        self.distribution_projector = nn.Linear(hidden_size, n_outputs)

        self.target_positions = torch.arange(self.target_dim)
        self.lagged_target_positions = {}
        self.n_reals = self.encoder_cont_dim
        self.n_categoricals = self.encoder_cat_dim

    @property
    def output_transformer(self):
        """Transformer passed to the loss for rescaling, or ``None`` if unset."""
        return getattr(self, "_output_transformer", None)

    @output_transformer.setter
    def output_transformer(self, value):
        self._output_transformer = value

    def on_fit_start(self):
        """Auto-detect the normalizer from the DataModule when training starts."""
        if self.output_transformer is None:
            datamodule = getattr(self.trainer, "datamodule", None)
            if datamodule is not None and hasattr(datamodule, "target_normalizer"):
                self.output_transformer = datamodule.target_normalizer

        if self.output_transformer is not None:
            self.hparams.output_transformer = self.output_transformer

    def transform_output(
        self,
        prediction: torch.Tensor,
        target_scale: torch.Tensor,
    ) -> torch.Tensor:
        """Apply scaling (from target_scale) back to the predicted parameters.

        Parameters
        ----------
        prediction : torch.Tensor
            Raw distribution parameters produced by the network.
        target_scale : torch.Tensor
            Scaling information forwarded to ``loss.rescale_parameters``.

        Returns
        -------
        torch.Tensor
            Output of ``loss.rescale_parameters``. For a single (non-Multi)
            loss with ``target_dim > 1`` the per-target results are flattened
            back into the last dimension.
        """
        encoder = self.output_transformer
        if encoder is None or isinstance(encoder, str):
            encoder = _DummyEncoder()

        if isinstance(self.loss, MultiLoss) or self.target_dim <= 1:
            return self.loss.rescale_parameters(
                prediction, target_scale=target_scale, encoder=encoder
            )

        # Single loss applied to several targets: fold the target axis into the
        # batch axis so the loss sees one target per row, then unfold again.
        n_params = prediction.size(-1) // self.target_dim
        batch_size, time_steps, _ = prediction.shape

        prediction = prediction.view(batch_size, time_steps, self.target_dim, n_params)
        prediction_flat = prediction.permute(0, 2, 1, 3).reshape(
            -1, time_steps, n_params
        )

        if target_scale.ndim == 2:
            target_scale = target_scale.unsqueeze(1)
        target_scale_flat = target_scale.reshape(-1, 2)

        rescaled_flat = self.loss.rescale_parameters(
            prediction_flat, target_scale=target_scale_flat, encoder=encoder
        )

        new_n_params = rescaled_flat.size(-1)
        rescaled = rescaled_flat.view(
            batch_size, self.target_dim, time_steps, new_n_params
        )
        rescaled = rescaled.permute(0, 2, 1, 3)
        return rescaled.reshape(batch_size, time_steps, -1)

    def construct_input_vector(
        self,
        x_cat: torch.Tensor,
        x_cont: torch.Tensor,
        one_off_target: torch.Tensor | None = None,
        is_encoder: bool = True,
    ) -> torch.Tensor:
        """Merges categoricals and continuous variables into a single vector
        for the RNN.

        Parameters
        ----------
        x_cat : torch.Tensor
            Categorical features.
        x_cont : torch.Tensor
            Continuous features.
        one_off_target : torch.Tensor, optional
            Value written into the target positions of the first time step
            after the targets were shifted by one step. If ``None``, the first
            time step is dropped instead.
        is_encoder : bool, optional
            Selects ``encoder_projector`` (True) or ``decoder_projector``.

        Returns
        -------
        torch.Tensor
            Projected input vector.

        Raises
        ------
        ValueError
            If the model has neither continuous nor categorical features.
        """
        if self.n_reals > 0 and self.n_categoricals > 0:
            input_vector = torch.cat([x_cont, x_cat], dim=-1)
        elif self.n_reals > 0:
            # clone so that the in-place writes below do not touch the batch
            input_vector = x_cont.clone()
        elif self.n_categoricals > 0:
            input_vector = x_cat.clone()
        else:
            raise ValueError("No features found in input")

        # Shift targets one step along dim 1 so each step sees the previous one.
        input_vector[..., self.target_positions] = torch.roll(
            input_vector[..., self.target_positions], shifts=1, dims=1
        )

        if one_off_target is not None:
            input_vector[:, 0, self.target_positions] = one_off_target.reshape(
                input_vector.size(0), -1
            )
        else:
            # the rolled-in value at step 0 wrapped around from the end
            input_vector = input_vector[:, 1:]

        projector = self.encoder_projector if is_encoder else self.decoder_projector
        return projector(input_vector)

    def encode(self, x: dict[str, torch.Tensor]) -> HiddenState:
        """Pass the encoder sequence through the RNN to get the initial hidden
        state for decoding."""
        input_vector = self.construct_input_vector(
            x["encoder_cat"], x["encoder_cont"], is_encoder=True
        )
        _, hidden_state = self.rnn(input_vector)
        return hidden_state

    def decode_all(
        self,
        x: torch.Tensor,
        hidden_state: HiddenState,
        lengths: torch.Tensor | None = None,
    ):
        """One-shot decoding pass through the RNN.

        ``lengths`` is accepted but currently unused.

        Returns
        -------
        tuple
            ``(output, hidden_state)`` where ``output`` is the RNN output
            passed through ``distribution_projector``.
        """
        decoder_output, hidden_state = self.rnn(x, hidden_state)
        output = self.distribution_projector(decoder_output)
        return output, hidden_state

    def output_to_prediction(
        self,
        prediction_params: torch.Tensor,
        target_scale: torch.Tensor,
        n_samples: int = 1,
    ):
        """Converts raw distribution params to actual predictions by sampling
        or scaling.

        Returns
        -------
        tuple
            ``(prediction, normalized_prediction)``: the drawn samples and the
            same samples with center subtracted and divided by scale.
        """
        rescaled_params = self.transform_output(
            prediction_params, target_scale=target_scale
        )

        # At least one sample is always drawn.
        prediction = self.loss.sample(rescaled_params, max(n_samples, 1))

        if target_scale.shape[-1] == 2:
            center = target_scale[..., 0].unsqueeze(-1)
            scale = target_scale[..., 1].unsqueeze(-1)
        else:
            # last dim is not a (center, scale) pair: use it as scale only
            center = torch.zeros_like(target_scale).unsqueeze(-1)
            scale = target_scale.unsqueeze(-1)

        while scale.ndim < prediction.ndim:
            scale = scale.unsqueeze(-1)
            center = center.unsqueeze(-1)

        normalized_prediction = (prediction - center) / scale
        return prediction, normalized_prediction

    def decode_autoregressive(
        self,
        decode_one: Callable,
        first_target: torch.Tensor,
        first_hidden_state: Any,
        target_scale: torch.Tensor,
        n_decoder_steps: int,
        n_samples: int = 1,
    ) -> torch.Tensor:
        """Loop through steps one by one, feeding previous prediction to the
        next step.

        Parameters
        ----------
        decode_one : Callable
            Called as ``decode_one(idx, lagged_targets=..., hidden_state=...)``
            and expected to return ``(prediction_params, hidden_state)``.
        first_target : torch.Tensor
            Seed entry of the lagged-target list.
        first_hidden_state : Any
            Hidden state used for the first step.
        target_scale : torch.Tensor
            Passed to ``output_to_prediction`` at every step.
        n_decoder_steps : int
            Number of steps to decode.
        n_samples : int, optional
            Passed to ``output_to_prediction``.

        Returns
        -------
        torch.Tensor
            Per-step predictions stacked along dim 1.
        """
        output = []
        current_hidden_state = first_hidden_state
        normalized_output = [first_target.unsqueeze(1)]

        for idx in range(n_decoder_steps):
            prediction_params, current_hidden_state = decode_one(
                idx,
                lagged_targets=normalized_output,
                hidden_state=current_hidden_state,
            )
            rescaled, normalized = self.output_to_prediction(
                prediction_params, target_scale, n_samples=n_samples
            )

            normalized_output.append(normalized.unsqueeze(1))
            output.append(rescaled)

        return torch.stack(output, dim=1)

    def decode(
        self,
        input_vector: torch.Tensor,
        target_scale: torch.Tensor,
        decoder_lengths: torch.Tensor,
        hidden_state: HiddenState,
        n_samples: int | None = None,
    ) -> torch.Tensor:
        """Interface for decoding, choosing between one-shot (training) and
        AR (inference).

        With ``n_samples=None`` the whole sequence is decoded in one pass and
        the rescaled distribution parameters are returned. Otherwise decoding
        runs step by step, drawing one sample per step. ``decoder_lengths`` is
        accepted but currently unused.
        """
        if n_samples is None:
            output, _ = self.decode_all(input_vector, hidden_state)
            return self.transform_output(output, target_scale)

        target_pos = self.target_positions

        def decode_one(idx, lagged_targets, hidden_state):
            # List indexing copies, so writing into ``x`` leaves
            # ``input_vector`` untouched.
            x = input_vector[:, [idx]]
            lag_val = lagged_targets[-1].squeeze(1)
            # NOTE(review): ``forward`` passes an ``input_vector`` that already
            # went through ``decoder_projector``, so the lagged target is
            # written into projected features here, whereas the one-shot path
            # sees targets before projection - confirm this is intended.
            x[:, 0, target_pos] = lag_val
            prediction_params, hidden_state = self.decode_all(x, hidden_state)
            return prediction_params[:, 0], hidden_state

        return self.decode_autoregressive(
            decode_one,
            first_target=input_vector[:, 0, target_pos],
            first_hidden_state=hidden_state,
            target_scale=target_scale,
            n_decoder_steps=input_vector.size(1),
            n_samples=1,
        )

    def _resolve_target_scale(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
        """Read ``x["target_scale"]`` and bring it into the layout used here.

        Missing scales default to center 0 and scale 1. The checks are applied
        in a fixed order; a 2-D tensor whose last dimension is 2 is always
        read as a (center, scale) pair.
        """
        target_scale = x.get("target_scale")
        if target_scale is None:
            batch_size = x["encoder_cont"].size(0)
            if self.target_dim > 1:
                shape = (batch_size, self.target_dim, 2)
            else:
                shape = (batch_size, 2)

            target_scale = torch.zeros(shape, device=x["encoder_cont"].device)
            target_scale[..., 1] = 1.0

        if target_scale.ndim == 1:
            target_scale = target_scale.unsqueeze(-1)

        if target_scale.ndim == 2:
            last_dim = target_scale.shape[-1]
            if last_dim == 2:
                if self.target_dim > 1:
                    # one (center, scale) pair shared by all targets
                    target_scale = target_scale.unsqueeze(1).expand(
                        -1, self.target_dim, -1
                    )
            elif last_dim == self.target_dim * 2:
                target_scale = target_scale.view(
                    target_scale.size(0), self.target_dim, 2
                )
            elif last_dim == self.target_dim:
                # scale only: pair it with a zero center
                target_scale = torch.stack(
                    [torch.zeros_like(target_scale), target_scale], dim=-1
                )

        if (
            self.target_dim == 1
            and target_scale.ndim == 3
            and target_scale.shape[1] == 1
        ):
            target_scale = target_scale.squeeze(1)

        if target_scale.ndim == 3 and target_scale.size(1) == 1 and self.target_dim > 1:
            target_scale = target_scale.expand(-1, self.target_dim, -1)

        return target_scale

    def forward(
        self, x: dict[str, torch.Tensor], n_samples: int | None = None
    ) -> dict[str, torch.Tensor]:
        """The main entry point for a forward pass (Batch -> Prediction).

        Parameters
        ----------
        x : dict[str, torch.Tensor]
            Batch with the keys ``encoder_cat``, ``encoder_cont``,
            ``decoder_cat``, ``decoder_cont``, ``encoder_lengths`` and
            ``decoder_lengths``; ``target_scale`` is optional.
        n_samples : int, optional
            ``None`` returns rescaled distribution parameters from a one-shot
            decode. Any integer switches to autoregressive sampling.

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary with the single key ``"prediction"``.

        Raises
        ------
        ValueError
            If ``n_samples`` is given while the model is in training mode.
        """
        if self.training and n_samples is not None:
            # Explicit raise instead of ``assert`` so the check survives ``-O``.
            raise ValueError(
                "n_samples must be None in training mode; "
                "sampling is only available for inference."
            )

        target_scale = self._resolve_target_scale(x)
        hidden_state = self.encode(x)

        # Target values at the last valid encoder step of every sample.
        last_encoder_target = x["encoder_cont"][
            torch.arange(x["encoder_cont"].size(0)).unsqueeze(-1),
            (x["encoder_lengths"] - 1).unsqueeze(-1),
            self.target_positions,
        ]

        input_vector = self.construct_input_vector(
            x["decoder_cat"],
            x["decoder_cont"],
            one_off_target=last_encoder_target,
            is_encoder=False,
        )

        batch_size = input_vector.size(0)
        draw_many = n_samples is not None and n_samples > 1

        if draw_many:
            # Repeat every series n_samples times and draw one trajectory each.
            input_vector = input_vector.repeat_interleave(n_samples, dim=0)
            hidden_state = apply_to_list(
                hidden_state, lambda t: t.repeat_interleave(n_samples, dim=0)
            )
            target_scale = target_scale.repeat_interleave(n_samples, dim=0)
            decode_samples = 1
        else:
            decode_samples = n_samples

        output = self.decode(
            input_vector,
            target_scale=target_scale,
            decoder_lengths=x["decoder_lengths"],
            hidden_state=hidden_state,
            n_samples=decode_samples,
        )

        if draw_many:
            # Unfold the repeated batch axis into a separate sample axis.
            if output.ndim == 2:
                output = output.view(batch_size, n_samples, -1).permute(0, 2, 1)
            elif output.ndim == 3:
                output = output.view(batch_size, n_samples, output.size(1), -1).permute(
                    0, 2, 1, 3
                )
                if output.shape[-1] == 1:
                    output = output.squeeze(-1)

        return {"prediction": output}
from typing import Literal

import pytest
import torch
import torch.nn as nn

from pytorch_forecasting.metrics import MAE, SMAPE, NormalDistributionLoss
from pytorch_forecasting.models.deepar._deepar_v2 import DeepAR as DeepAR_v2

BATCH_SIZE_TEST = 2
MAX_ENCODER_LENGTH_TEST = 10
MAX_PREDICTION_LENGTH_TEST = 5
HIDDEN_SIZE_TEST = 8
RNN_LAYERS_TEST = 1
DROPOUT_TEST = 0.1


def get_default_test_metadata(
    enc_cont=2,
    enc_cat=1,
    dec_cont=2,
    dec_cat=1,
    target_dim=1,
):
    """Build the metadata dict handed to the model in these tests."""
    metadata = {
        "max_encoder_length": MAX_ENCODER_LENGTH_TEST,
        "max_prediction_length": MAX_PREDICTION_LENGTH_TEST,
    }
    metadata["encoder_cont"] = enc_cont
    metadata["encoder_cat"] = enc_cat
    metadata["decoder_cont"] = dec_cont
    metadata["decoder_cat"] = dec_cat
    metadata["target_dim"] = target_dim
    return metadata


def create_deepar_input_batch_for_test(
    metadata, batch_size=BATCH_SIZE_TEST, device="cpu"
):
    """Create a synthetic input batch dictionary for testing DeepAR forward passes.

    Feature tensors are random; both length tensors are filled with the
    respective maximum length from ``metadata``.
    """

    def random_features(seq_len, n_features):
        return torch.randn(batch_size, seq_len, n_features, device=device)

    def constant_lengths(seq_len):
        return torch.full((batch_size,), seq_len, dtype=torch.long, device=device)

    batch = {}
    batch["encoder_cont"] = random_features(
        metadata["max_encoder_length"], metadata.get("encoder_cont", 0)
    )
    batch["encoder_cat"] = random_features(
        metadata["max_encoder_length"], metadata.get("encoder_cat", 0)
    )
    batch["decoder_cont"] = random_features(
        metadata["max_prediction_length"], metadata.get("decoder_cont", 0)
    )
    batch["decoder_cat"] = random_features(
        metadata["max_prediction_length"], metadata.get("decoder_cat", 0)
    )
    batch["encoder_lengths"] = constant_lengths(metadata["max_encoder_length"])
    batch["decoder_lengths"] = constant_lengths(metadata["max_prediction_length"])
    return batch


@pytest.fixture
def deepar_model_params_fixture():
    """Create basic model parameters for DeepAR."""
    model_params = dict(
        loss=NormalDistributionLoss(),
        hidden_size=HIDDEN_SIZE_TEST,
        rnn_layers=RNN_LAYERS_TEST,
        dropout=DROPOUT_TEST,
        cell_type="LSTM",
    )
    return model_params


def test_deepar_v2_initialization(deepar_model_params_fixture):
    """Test basic initialization of the DeepAR V2 model."""
    meta = get_default_test_metadata()
    net = DeepAR_v2(**deepar_model_params_fixture, metadata=meta)

    # sizes passed directly to the constructor
    assert net.hidden_size == HIDDEN_SIZE_TEST
    assert net.rnn_layers == RNN_LAYERS_TEST
    # sizes taken from the metadata
    assert net.max_encoder_length == MAX_ENCODER_LENGTH_TEST
    assert net.max_prediction_length == MAX_PREDICTION_LENGTH_TEST
    assert net.encoder_cont_dim == meta["encoder_cont"]
    assert net.target_dim == meta["target_dim"]


@pytest.mark.parametrize("cell_type", ["LSTM", "GRU"])
def test_deepar_v2_forward_pass(deepar_model_params_fixture, cell_type):
    """Test DeepAR V2 forward pass with different cell types."""
    meta = get_default_test_metadata()
    model_kwargs = {**deepar_model_params_fixture, "cell_type": cell_type}

    net = DeepAR_v2(**model_kwargs, metadata=meta)
    net.eval()

    result = net(create_deepar_input_batch_for_test(meta))

    assert "prediction" in result
    pred = result["prediction"]

    # The test expects two extra values per target on top of the
    # distribution arguments.
    params_per_target = len(model_kwargs["loss"].distribution_arguments) + 2
    expected_shape = (
        BATCH_SIZE_TEST,
        MAX_PREDICTION_LENGTH_TEST,
        meta["target_dim"] * params_per_target,
    )
    assert pred.shape == expected_shape
    assert not torch.isnan(pred).any()


def test_deepar_v2_multi_target(deepar_model_params_fixture):
    """Test DeepAR V2 forward pass with multiple targets."""
    n_targets = 3
    meta = get_default_test_metadata(target_dim=n_targets, enc_cont=n_targets)
    net = DeepAR_v2(**deepar_model_params_fixture, metadata=meta)
    net.eval()

    pred = net(create_deepar_input_batch_for_test(meta))["prediction"]
    params_per_target = (
        len(deepar_model_params_fixture["loss"].distribution_arguments) + 2
    )

    if not isinstance(pred, list):
        # single tensor: all targets share the last dimension
        assert pred.shape == (
            BATCH_SIZE_TEST,
            MAX_PREDICTION_LENGTH_TEST,
            n_targets * params_per_target,
        )
        return

    # list output: one tensor per target
    assert len(pred) == n_targets
    for target_pred in pred:
        assert target_pred.shape == (
            BATCH_SIZE_TEST,
            MAX_PREDICTION_LENGTH_TEST,
            params_per_target,
        )