Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 54 additions & 13 deletions pytorch_forecasting/models/base/_base_model_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def __init__(
lr_scheduler_params: dict | None = None,
):
super().__init__()
if isinstance(loss, list):
loss = MultiLoss(loss)
self.loss = loss
self.logging_metrics = logging_metrics if logging_metrics is not None else []
self.optimizer = optimizer
Expand All @@ -78,6 +80,13 @@ def pkg(cls):
"""Package class for the model."""
return cls._pkg()

@property
def n_targets(self) -> int:
    """Number of forecast targets, inferred from the configured loss.

    A ``MultiLoss`` wraps exactly one metric per target, so its metric
    count equals the target count; any other loss implies one target.
    """
    loss_fn = self.loss
    return len(loss_fn.metrics) if isinstance(loss_fn, MultiLoss) else 1

def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""
Forward pass of the model.
Expand Down Expand Up @@ -139,18 +148,20 @@ def predict(

return predict_callback.result

def to_prediction(self, out: dict[str, Any], **kwargs) -> torch.Tensor:
def to_prediction(
    self, out: dict[str, Any], **kwargs
) -> torch.Tensor | list[torch.Tensor]:
    """Converts raw model output to point forecasts.

    Parameters
    ----------
    out : dict[str, Any]
        Raw model output; must contain a ``"prediction"`` entry, which is
        passed unchanged to ``self.loss.to_prediction`` (presumably a
        tensor, or a list of per-target tensors for MultiLoss — confirm
        against the loss implementation).
    **kwargs
        Extra arguments forwarded to ``self.loss.to_prediction`` when it
        accepts them.

    Returns
    -------
    torch.Tensor or list of torch.Tensor
        Point forecasts as produced by the loss's ``to_prediction``.
    """
    # EAFP: call with kwargs first, and retry without them if the loss's
    # to_prediction signature rejects them (raises TypeError).
    try:
        out = self.loss.to_prediction(out["prediction"], **kwargs)
    except TypeError: # in case passed kwargs do not exist
        out = self.loss.to_prediction(out["prediction"])
    return out

def to_quantiles(self, out: dict[str, Any], **kwargs) -> torch.Tensor:
def to_quantiles(
self, out: dict[str, Any], **kwargs
) -> torch.Tensor | list[torch.Tensor]:
"""Converts raw model output to quantile forecasts."""
# todo: add MultiLoss support
try:
out = self.loss.to_quantiles(out["prediction"], **kwargs)
except TypeError: # in case passed kwargs do not exist
Expand Down Expand Up @@ -178,9 +189,19 @@ def training_step(
x, y = batch
y_hat_dict = self(x)
y_hat = y_hat_dict["prediction"]
loss = self.loss(y_hat, y)
if isinstance(self.loss, MultiLoss):
if not isinstance(y_hat, list | tuple):
y_hat = [y_hat]
loss = self.loss(list(y_hat), y)
else:
loss = self.loss(y_hat, y)
self.log(
"train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
"train_loss",
loss,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
self.log_metrics(y_hat, y, prefix="train")
return {"loss": loss}
Expand All @@ -206,7 +227,12 @@ def validation_step(
x, y = batch
y_hat_dict = self(x)
y_hat = y_hat_dict["prediction"]
loss = self.loss(y_hat, y)
if isinstance(self.loss, MultiLoss):
if not isinstance(y_hat, list | tuple):
y_hat = [y_hat]
loss = self.loss(list(y_hat), y)
else:
loss = self.loss(y_hat, y)
self.log(
"val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
)
Expand Down Expand Up @@ -234,7 +260,12 @@ def test_step(
x, y = batch
y_hat_dict = self(x)
y_hat = y_hat_dict["prediction"]
loss = self.loss(y_hat, y)
if isinstance(self.loss, MultiLoss):
if not isinstance(y_hat, list | tuple):
y_hat = [y_hat]
loss = self.loss(list(y_hat), y)
else:
loss = self.loss(y_hat, y)
self.log(
"test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
)
Expand Down Expand Up @@ -344,20 +375,30 @@ def _get_scheduler(
raise ValueError(f"Scheduler {self.lr_scheduler} not supported.")

def log_metrics(
self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = "val"
self,
y_hat: torch.Tensor | list[torch.Tensor],
y: torch.Tensor | tuple,
prefix: str = "val",
) -> None:
"""
Log additional metrics during training, validation, or testing.

Parameters
----------
y_hat : torch.Tensor
Predicted output tensor.
y : torch.Tensor
Target output tensor.
y_hat : torch.Tensor or list of torch.Tensor
Predicted output tensor, or list of per-target tensors for MultiLoss.
y : torch.Tensor or tuple
Target output tensor, or ``(targets, weights)`` tuple for MultiLoss.
prefix : str
Prefix for the logged metrics (e.g., "train", "val", "test").
"""
if not self.logging_metrics:
return

if isinstance(self.loss, MultiLoss):
# Currently skipped to avoid type mismatch errors (list vs tensor).
return

for metric in self.logging_metrics:
metric_value = metric(y_hat, y)
self.log(
Expand Down
26 changes: 13 additions & 13 deletions pytorch_forecasting/models/samformer/_samformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def __init__(
# specific params
hidden_size: int,
use_revin: bool,
# out_channels has to be 1, due to lack of MultiLoss support in v2.
out_channels: int | list[int] | None = 1,
persistence_weight: float = 0.0,
logging_metrics: list[nn.Module] | None = None,
Expand Down Expand Up @@ -77,14 +76,10 @@ def __init__(
self.max_encoder_length = self.metadata["max_encoder_length"]
self.max_prediction_length = self.metadata["max_prediction_length"]
self.encoder_cont = self.metadata["encoder_cont"]
self.encoder_input_dim = self.encoder_cont + 1 # +1 for target variable input.

self.encoder_input_dim = self.encoder_cont + self.n_targets

self.hidden_size = hidden_size
if out_channels != 1:
raise ValueError(
"out_channels has to be 1 for Samformer,",
" due to lack of MultiLoss support in v2.",
)
self.out_channels = out_channels
self.use_revin = use_revin
self.persistence_weight = persistence_weight
Expand Down Expand Up @@ -172,15 +167,20 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:

out = out.transpose(1, 2)

target_predictions = out[:, :, -1] # (batch_size, max_prediction_length)

if target_predictions.ndim == 1:
target_predictions = target_predictions.unsqueeze(0)
target_predictions = out[:, :, -self.n_targets :]

if self.n_quantiles > 1:
target_predictions = target_predictions.unsqueeze(-1).expand(
-1, -1, self.n_quantiles
-1, -1, -1, self.n_quantiles
)
elif self.n_quantiles == 1:
else:
target_predictions = target_predictions.unsqueeze(-1)

if self.n_targets > 1:
target_predictions = [
target_predictions[:, :, i, :] for i in range(self.n_targets)
]
else:
target_predictions = target_predictions.squeeze(2)

return {"prediction": target_predictions}
Loading