Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
da9b21a
[MNT] fix end of file and ruff formatting
Meet-Ramjiyani-10 Mar 4, 2026
d47c878
Merge branch 'main' into enh/rnn-v2-interface
Meet-Ramjiyani-10 Mar 4, 2026
825b12d
[ENH] address review comments - remove _get_test_datamodule_from, upd…
Meet-Ramjiyani-10 Mar 4, 2026
8e26e9d
Merge branch 'enh/rnn-v2-interface' of https://github.com/Meet-Ramjiy…
Meet-Ramjiyani-10 Mar 4, 2026
3498d10
Merge branch 'main' into enh/rnn-v2-interface
Meet-Ramjiyani-10 Mar 4, 2026
e32bb6f
Merge remote-tracking branch 'upstream/main' into enh/rnn-v2-interface
Meet-Ramjiyani-10 Mar 5, 2026
2c3d47e
[ENH] export RecurrentNetwork_v2 and RecurrentNetwork_pkg_v2 from rnn…
Meet-Ramjiyani-10 Mar 5, 2026
57721fb
Merge branch 'enh/rnn-v2-interface' of https://github.com/Meet-Ramjiy…
Meet-Ramjiyani-10 Mar 5, 2026
ddb218d
[BUG] fix info:name tag in RecurrentNetwork_pkg_v2 to match class name
Meet-Ramjiyani-10 Mar 5, 2026
e15c963
[BUG] fix info:name tag and switch to EncoderDecoderTimeSeriesDataModule
Meet-Ramjiyani-10 Mar 5, 2026
d56dab1
[ENH] update RecurrentNetwork_v2 to use BaseModel and encoder_cont keys
Meet-Ramjiyani-10 Mar 5, 2026
fc53b3e
[BUG] fix _pkg method to use renamed package class
Meet-Ramjiyani-10 Mar 5, 2026
df73720
[BUG] fix info:name tag in RecurrentNetwork_pkg_v2 to match class name
Meet-Ramjiyani-10 Mar 5, 2026
014f63d
[BUG] fix info:name tag in RecurrentNetwork_pkg_v2 to match class name
Meet-Ramjiyani-10 Mar 5, 2026
da1025b
[ENH] rename RNN v2 class to RNN and pkg to RNN_pkg_v2 following v2 c…
Meet-Ramjiyani-10 Mar 5, 2026
e0b007a
Merge branch 'main' into enh/rnn-v2-interface
Meet-Ramjiyani-10 Mar 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions pytorch_forecasting/models/rnn/_rnn_pkg_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""
Package container for RecurrentNetwork v2 model.
"""

from pytorch_forecasting.base._base_pkg import Base_pkg


class RecurrentNetwork_pkg_v2(Base_pkg):
    """Package container for the v2 RecurrentNetwork (LSTM/GRU) model.

    Holds registry metadata tags and class references; the model itself
    lives in ``pytorch_forecasting.models.rnn._rnn_v2``.
    """

    _tags = {
        # Must match the model class name returned by ``get_cls`` so the
        # registry resolves this package to its model (was "RecurrentNetwork",
        # which does not match RecurrentNetwork_v2).
        "info:name": "RecurrentNetwork_v2",
        "info:compute": 2,
        "authors": ["Meet-Ramjiyani-10"],
        "capability:exogenous": True,
        "capability:multivariate": True,
        "capability:pred_int": True,
        "capability:flexible_history_length": True,
        "capability:cold_start": False,
    }

@classmethod
def get_cls(cls):
"""Get model class."""
from pytorch_forecasting.models.rnn._rnn_v2 import RecurrentNetwork_v2

return RecurrentNetwork_v2

    @classmethod
    def get_datamodule_cls(cls):
        """Get the underlying DataModule class.

        Returns
        -------
        type
            ``TslibDataModule``, the datamodule this model is tested with.
        """
        # NOTE(review): a later commit on this branch mentions switching to
        # EncoderDecoderTimeSeriesDataModule — confirm TslibDataModule is
        # still the intended datamodule here.
        from pytorch_forecasting.data._tslib_data_module import TslibDataModule

        return TslibDataModule

@classmethod
def _get_test_datamodule_from(cls, trainer_kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think you need this method anymore? Please have a look at other pkgs (like timexer) for more info

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for pointing. After checking _timexer_pkg_v2.py. this method is not needed . Will remove it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@phoeenniixx done 👍

"""Create test dataloaders from trainer_kwargs."""
from pytorch_forecasting.data._tslib_data_module import TslibDataModule
from pytorch_forecasting.tests._data_scenarios import (
data_with_covariates_v2,
make_datasets_v2,
)

data_with_covariates = data_with_covariates_v2()
data_loader_default_kwargs = dict(
target="target",
group_ids=["agency_encoded", "sku_encoded"],
add_relative_time_idx=True,
)

data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {})
data_loader_default_kwargs.update(data_loader_kwargs)

datasets_info = make_datasets_v2(
data_with_covariates, **data_loader_default_kwargs
)

training_dataset = datasets_info["training_dataset"]
validation_dataset = datasets_info["validation_dataset"]

context_length = data_loader_kwargs.get("context_length", 8)
prediction_length = data_loader_kwargs.get("prediction_length", 2)
batch_size = data_loader_kwargs.get("batch_size", 2)

train_datamodule = TslibDataModule(
time_series_dataset=training_dataset,
context_length=context_length,
prediction_length=prediction_length,
add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True),
batch_size=batch_size,
train_val_test_split=(0.8, 0.2, 0.0),
)

val_datamodule = TslibDataModule(
time_series_dataset=validation_dataset,
context_length=context_length,
prediction_length=prediction_length,
add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True),
batch_size=batch_size,
train_val_test_split=(0.0, 1.0, 0.0),
)

test_datamodule = TslibDataModule(
time_series_dataset=validation_dataset,
context_length=context_length,
prediction_length=prediction_length,
add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True),
batch_size=batch_size,
train_val_test_split=(0.0, 0.0, 1.0),
)

train_datamodule.setup("fit")
val_datamodule.setup("fit")
test_datamodule.setup("test")

train_dataloader = train_datamodule.train_dataloader()
val_dataloader = val_datamodule.val_dataloader()
test_dataloader = test_datamodule.test_dataloader()

return {
"train": train_dataloader,
"val": val_dataloader,
"test": test_dataloader,
"data_module": train_datamodule,
}

@classmethod
def get_test_train_params(cls):
"""
Return testing parameter settings for the trainer.
Returns
-------
params : list of dict
Parameters to create testing instances of the class.
"""
from pytorch_forecasting.metrics import MAE, SMAPE, QuantileLoss

params = [
{},
dict(cell_type="LSTM", hidden_size=32, rnn_layers=1),
dict(cell_type="GRU", hidden_size=32, rnn_layers=1),
dict(
cell_type="LSTM",
hidden_size=16,
rnn_layers=2,
dropout=0.1,
),
]

default_dm_cfg = {"context_length": 8, "prediction_length": 2}

for param in params:
current_dm_cfg = param.get("datamodule_cfg", {})
default_dm_cfg.update(current_dm_cfg)
param["datamodule_cfg"] = default_dm_cfg

return params
170 changes: 170 additions & 0 deletions pytorch_forecasting/models/rnn/_rnn_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""
Recurrent Network (LSTM/GRU) model for PyTorch Forecasting v2.
---------------------------------------------------------------
"""

from typing import Any

import torch
import torch.nn as nn
from torch.optim import Optimizer

from pytorch_forecasting.metrics import QuantileLoss
from pytorch_forecasting.models.base._tslib_base_model_v2 import TslibBaseModel


class RecurrentNetwork_v2(TslibBaseModel):
    """
    Recurrent network (LSTM/GRU) forecaster for PyTorch Forecasting v2.

    Past continuous covariates and past target values are concatenated along
    the feature axis, encoded by a recurrent layer, and the hidden state of
    the final time step is linearly projected onto the prediction horizon.

    Parameters
    ----------
    loss : nn.Module
        Loss function for training.
    cell_type : str, optional
        Recurrent cell type, either "LSTM" or "GRU". Default is "LSTM".
    hidden_size : int, optional
        Number of features in the hidden state. Default is 64.
    rnn_layers : int, optional
        Number of recurrent layers. Default is 2.
    dropout : float, optional
        Dropout rate between RNN layers. Default is 0.1.
    logging_metrics : list[nn.Module], optional
        Metrics to log during training. Default is None.
    optimizer : str or Optimizer, optional
        Optimizer to use. Default is "adam".
    optimizer_params : dict, optional
        Parameters for the optimizer. Default is None.
    lr_scheduler : str, optional
        Learning rate scheduler. Default is None.
    lr_scheduler_params : dict, optional
        Parameters for the scheduler. Default is None.
    metadata : dict, optional
        Metadata from TslibDataModule. Default is None.
    """

    @classmethod
    def _pkg(cls):
        """Return the package class that wraps this model."""
        from pytorch_forecasting.models.rnn._rnn_pkg_v2 import (
            RecurrentNetwork_pkg_v2,
        )

        return RecurrentNetwork_pkg_v2

    def __init__(
        self,
        loss: nn.Module,
        cell_type: str = "LSTM",
        hidden_size: int = 64,
        rnn_layers: int = 2,
        dropout: float = 0.1,
        logging_metrics: list[nn.Module] | None = None,
        optimizer: Optimizer | str | None = "adam",
        optimizer_params: dict | None = None,
        lr_scheduler: str | None = None,
        lr_scheduler_params: dict | None = None,
        metadata: dict | None = None,
        **kwargs: Any,
    ):
        # NOTE(review): **kwargs is accepted but not forwarded anywhere —
        # confirm extra keyword arguments are meant to be ignored.
        super().__init__(
            loss=loss,
            logging_metrics=logging_metrics,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            lr_scheduler=lr_scheduler,
            lr_scheduler_params=lr_scheduler_params,
            metadata=metadata,
        )

        assert cell_type in {
            "LSTM",
            "GRU",
        }, f"cell_type must be 'LSTM' or 'GRU', got '{cell_type}'"

        self.cell_type = cell_type
        self.hidden_size = hidden_size
        self.rnn_layers = rnn_layers
        self.dropout = dropout

        self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"])

        # Number of quantile outputs per horizon step; None for point
        # forecasts (non-quantile losses).
        self.n_quantiles = (
            len(loss.quantiles) if isinstance(loss, QuantileLoss) else None
        )

        self._init_network()

    def _init_network(self):
        """Build the recurrent encoder and the output projection layer."""
        # assumes cont_dim / target_dim / prediction_length are populated by
        # the TslibBaseModel base class from the datamodule metadata —
        # TODO confirm
        encoder_input_size = self.cont_dim + self.target_dim

        rnn_cls = nn.LSTM if self.cell_type == "LSTM" else nn.GRU
        self.rnn = rnn_cls(
            input_size=encoder_input_size,
            hidden_size=self.hidden_size,
            num_layers=self.rnn_layers,
            # torch warns when dropout is set on a single-layer RNN, so
            # disable inter-layer dropout in that case.
            dropout=self.dropout if self.rnn_layers > 1 else 0,
            batch_first=True,
        )

        # Per-step output width: quantiles for QuantileLoss, otherwise one
        # value per target variable.
        # NOTE(review): when both n_quantiles and target_dim > 1 apply, the
        # quantile branch wins and target_dim is ignored — verify that
        # multivariate quantile forecasting is out of scope here.
        per_step = (
            self.n_quantiles if self.n_quantiles is not None else self.target_dim
        )
        self.output_projector = nn.Linear(
            self.hidden_size, self.prediction_length * per_step
        )

    def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Forward pass of the RecurrentNetwork model.

        Parameters
        ----------
        x : dict[str, torch.Tensor]
            Dictionary of input tensors; reads ``"history_cont"`` and
            ``"history_target"`` (at least one must be present and
            non-empty) and optionally ``"target_scale"``.

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary with key ``"prediction"``.

        Raises
        ------
        ValueError
            If neither ``"history_cont"`` nor ``"history_target"`` provides
            any features.
        """
        feature_tensors = [
            x[key]
            for key in ("history_cont", "history_target")
            if key in x and x[key].size(-1) > 0
        ]
        if not feature_tensors:
            raise ValueError("No valid input features found in input dictionary.")

        encoded, _ = self.rnn(torch.cat(feature_tensors, dim=-1))
        # Only the final time step's hidden state is projected onto the
        # whole prediction horizon.
        projected = self.output_projector(encoded[:, -1, :])

        last_dim = (
            self.n_quantiles if self.n_quantiles is not None else self.target_dim
        )
        prediction = projected.reshape(-1, self.prediction_length, last_dim)

        # Rescale to the original target scale when the base class provides
        # a transform and the batch carries the scale.
        if "target_scale" in x and hasattr(self, "transform_output"):
            prediction = self.transform_output(prediction, x["target_scale"])

        return {"prediction": prediction}
Loading