Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
da9b21a
[MNT] fix end of file and ruff formatting
Meet-Ramjiyani-10 Mar 4, 2026
d47c878
Merge branch 'main' into enh/rnn-v2-interface
Meet-Ramjiyani-10 Mar 4, 2026
825b12d
[ENH] address review comments - remove _get_test_datamodule_from, upd…
Meet-Ramjiyani-10 Mar 4, 2026
8e26e9d
Merge branch 'enh/rnn-v2-interface' of https://github.com/Meet-Ramjiy…
Meet-Ramjiyani-10 Mar 4, 2026
3498d10
Merge branch 'main' into enh/rnn-v2-interface
Meet-Ramjiyani-10 Mar 4, 2026
e32bb6f
Merge remote-tracking branch 'upstream/main' into enh/rnn-v2-interface
Meet-Ramjiyani-10 Mar 5, 2026
2c3d47e
[ENH] export RecurrentNetwork_v2 and RecurrentNetwork_pkg_v2 from rnn…
Meet-Ramjiyani-10 Mar 5, 2026
57721fb
Merge branch 'enh/rnn-v2-interface' of https://github.com/Meet-Ramjiy…
Meet-Ramjiyani-10 Mar 5, 2026
ddb218d
[BUG] fix info:name tag in RecurrentNetwork_pkg_v2 to match class name
Meet-Ramjiyani-10 Mar 5, 2026
e15c963
[BUG] fix info:name tag and switch to EncoderDecoderTimeSeriesDataModule
Meet-Ramjiyani-10 Mar 5, 2026
d56dab1
[ENH] update RecurrentNetwork_v2 to use BaseModel and encoder_cont keys
Meet-Ramjiyani-10 Mar 5, 2026
fc53b3e
[BUG] fix _pkg method to use renamed package class
Meet-Ramjiyani-10 Mar 5, 2026
df73720
[BUG] fix info:name tag in RecurrentNetwork_pkg_v2 to match class name
Meet-Ramjiyani-10 Mar 5, 2026
014f63d
[BUG] fix info:name tag in RecurrentNetwork_pkg_v2 to match class name
Meet-Ramjiyani-10 Mar 5, 2026
da1025b
[ENH] rename RNN v2 class to RNN and pkg to RNN_pkg_v2 following v2 c…
Meet-Ramjiyani-10 Mar 5, 2026
e0b007a
Merge branch 'main' into enh/rnn-v2-interface
Meet-Ramjiyani-10 Mar 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pytorch_forecasting/models/rnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,12 @@

from pytorch_forecasting.models.rnn._rnn import RecurrentNetwork
from pytorch_forecasting.models.rnn._rnn_pkg import RecurrentNetwork_pkg
from pytorch_forecasting.models.rnn._rnn_pkg_v2 import RNN_pkg_v2
from pytorch_forecasting.models.rnn._rnn_v2 import RNN

__all__ = ["RecurrentNetwork", "RecurrentNetwork_pkg"]
# Public API of the rnn subpackage: v1 model/package plus the v2
# (BaseModel-based) model and its package container.
__all__ = [
    "RecurrentNetwork",
    "RecurrentNetwork_pkg",
    "RNN",
    "RNN_pkg_v2",
]
80 changes: 80 additions & 0 deletions pytorch_forecasting/models/rnn/_rnn_pkg_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
Package container for RNN v2 model.
"""

from pytorch_forecasting.base._base_pkg import Base_pkg


class RNN_pkg_v2(Base_pkg):
    """RNN v2 package container."""

    _tags = {
        "info:name": "RNN",
        "info:compute": 2,
        "authors": ["Meet-Ramjiyani-10"],
        "capability:exogenous": True,
        "capability:multivariate": True,
        "capability:pred_int": True,
        "capability:flexible_history_length": True,
        "capability:cold_start": False,
    }

    @classmethod
    def get_cls(cls):
        """Get model class."""
        from pytorch_forecasting.models.rnn._rnn_v2 import RNN

        return RNN

    @classmethod
    def get_datamodule_cls(cls):
        """Get the underlying DataModule class."""
        from pytorch_forecasting.data.data_module import (
            EncoderDecoderTimeSeriesDataModule,
        )

        return EncoderDecoderTimeSeriesDataModule

    @classmethod
    def get_test_train_params(cls):
        """
        Return testing parameter settings for the trainer.

        Covers MAE, SMAPE, and QuantileLoss to verify loss compatibility,
        both supported cell types, and a multi-layer configuration with
        dropout.

        Returns
        -------
        params : list of dict
            Parameters to create testing instances of the class. Each dict
            carries its own independent ``datamodule_cfg``.
        """
        from pytorch_forecasting.metrics import MAE, SMAPE, QuantileLoss

        params = [
            dict(loss=MAE()),
            dict(
                loss=SMAPE(),
                cell_type="LSTM",
                hidden_size=32,
                rnn_layers=1,
            ),
            dict(
                loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
                cell_type="GRU",
                hidden_size=32,
                rnn_layers=1,
            ),
            dict(
                loss=MAE(),
                cell_type="LSTM",
                hidden_size=16,
                rnn_layers=2,
                dropout=0.1,
            ),
        ]

        default_dm_cfg = {"max_encoder_length": 8, "max_prediction_length": 2}

        # Build a FRESH merged dict for each param. The previous version
        # called default_dm_cfg.update(...) in the loop and assigned the
        # same shared dict to every param, so per-param overrides leaked
        # into all other params and all params aliased one object.
        for param in params:
            current_dm_cfg = param.get("datamodule_cfg", {})
            param["datamodule_cfg"] = {**default_dm_cfg, **current_dm_cfg}

        return params
173 changes: 173 additions & 0 deletions pytorch_forecasting/models/rnn/_rnn_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""
Recurrent Network (LSTM/GRU) model for PyTorch Forecasting v2.
---------------------------------------------------------------
"""

from typing import Any

import torch
import torch.nn as nn
from torch.optim import Optimizer

from pytorch_forecasting.metrics import QuantileLoss
from pytorch_forecasting.models.base._base_model_v2 import BaseModel


class RNN(BaseModel):
    """
    Recurrent Network model for time series forecasting.

    Supports LSTM and GRU cell types. Encodes the input sequence
    using a recurrent layer and projects the final hidden state
    to the prediction horizon.

    Parameters
    ----------
    loss : nn.Module
        Loss function for training.
    cell_type : str, optional
        Recurrent cell type, either "LSTM" or "GRU". Default is "LSTM".
    hidden_size : int, optional
        Number of features in the hidden state. Default is 64.
    rnn_layers : int, optional
        Number of recurrent layers. Default is 2.
    dropout : float, optional
        Dropout rate between RNN layers. Default is 0.1.
    logging_metrics : list[nn.Module], optional
        Metrics to log during training. Default is None.
    optimizer : str or Optimizer, optional
        Optimizer to use. Default is "adam".
    optimizer_params : dict, optional
        Parameters for the optimizer. Default is None.
    lr_scheduler : str, optional
        Learning rate scheduler. Default is None.
    lr_scheduler_params : dict, optional
        Parameters for the scheduler. Default is None.
    metadata : dict, optional
        Metadata from EncoderDecoderTimeSeriesDataModule. Default is None.

    Raises
    ------
    ValueError
        If ``cell_type`` is not ``"LSTM"`` or ``"GRU"``.
    """

    @classmethod
    def _pkg(cls):
        """Package containing the model."""
        from pytorch_forecasting.models.rnn._rnn_pkg_v2 import RNN_pkg_v2

        return RNN_pkg_v2

    def __init__(
        self,
        loss: nn.Module,
        cell_type: str = "LSTM",
        hidden_size: int = 64,
        rnn_layers: int = 2,
        dropout: float = 0.1,
        logging_metrics: list[nn.Module] | None = None,
        optimizer: Optimizer | str | None = "adam",
        optimizer_params: dict | None = None,
        lr_scheduler: str | None = None,
        lr_scheduler_params: dict | None = None,
        metadata: dict | None = None,
        **kwargs: Any,
    ):
        super().__init__(
            loss=loss,
            logging_metrics=logging_metrics,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            lr_scheduler=lr_scheduler,
            lr_scheduler_params=lr_scheduler_params,
        )

        # Raise instead of assert: asserts are stripped under ``python -O``,
        # which would let an invalid cell_type through silently.
        if cell_type not in ("LSTM", "GRU"):
            raise ValueError(
                f"cell_type must be 'LSTM' or 'GRU', got '{cell_type}'"
            )

        self.cell_type = cell_type
        self.hidden_size = hidden_size
        self.rnn_layers = rnn_layers
        self.dropout = dropout
        self.metadata = metadata

        self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"])

        # Network sizes come from the datamodule metadata; a missing key
        # fails loudly here rather than at first forward pass.
        self.max_encoder_length = metadata["max_encoder_length"]
        self.max_prediction_length = metadata["max_prediction_length"]
        # Counts of continuous / categorical encoder features, summed into
        # the RNN input width.
        self.encoder_cont = metadata["encoder_cont"]
        self.encoder_cat = metadata["encoder_cat"]
        self.input_dim = self.encoder_cont + self.encoder_cat

        # QuantileLoss needs one output channel per quantile.
        self.n_quantiles = None
        if isinstance(loss, QuantileLoss):
            self.n_quantiles = len(loss.quantiles)

        self._init_network()

    def _init_network(self):
        """Initialize the recurrent encoder and the output projection layer."""
        # Single construction path for both cell types (previously duplicated).
        rnn_cls = nn.LSTM if self.cell_type == "LSTM" else nn.GRU
        self.rnn = rnn_cls(
            # At least one input feature even when there are no covariates;
            # forward() substitutes a zero feature in that case.
            input_size=max(1, self.input_dim),
            hidden_size=self.hidden_size,
            num_layers=self.rnn_layers,
            # torch warns when dropout is set on a single-layer RNN.
            dropout=self.dropout if self.rnn_layers > 1 else 0,
            batch_first=True,
        )

        if self.n_quantiles is not None:
            output_size = self.max_prediction_length * self.n_quantiles
        else:
            output_size = self.max_prediction_length

        self.output_projector = nn.Linear(self.hidden_size, output_size)

    def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Forward pass of the RNN model.

        Parameters
        ----------
        x : dict[str, torch.Tensor]
            Dictionary containing input tensors. At least one of
            "encoder_cont" or "encoder_cat" must be present.

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary with key "prediction" of shape
            ``(batch, max_prediction_length, n_quantiles or 1)``.

        Raises
        ------
        KeyError
            If neither "encoder_cont" nor "encoder_cat" is provided.
        """
        encoder_cont = x.get("encoder_cont")
        encoder_cat = x.get("encoder_cat")

        # Derive batch size from whichever stream is present. The previous
        # version indexed x["encoder_cont"] unconditionally for batch_size,
        # which raised KeyError before its own .get() fallback could apply.
        reference = encoder_cont if encoder_cont is not None else encoder_cat
        if reference is None:
            raise KeyError(
                "forward requires 'encoder_cont' or 'encoder_cat' in input dict"
            )
        batch_size = reference.shape[0]

        # Substitute empty (zero-feature) tensors for missing streams.
        # NOTE(review): assumes the present stream has max_encoder_length
        # time steps — confirm against the datamodule.
        if encoder_cont is None:
            encoder_cont = torch.zeros(
                batch_size, self.max_encoder_length, 0, device=self.device
            )
        if encoder_cat is None:
            encoder_cat = torch.zeros(
                batch_size, self.max_encoder_length, 0, device=self.device
            )

        input_data = torch.cat([encoder_cont, encoder_cat], dim=-1)

        # With no covariates at all, feed a single zero feature to match the
        # max(1, input_dim) input size chosen in _init_network.
        if input_data.size(-1) == 0:
            input_data = torch.zeros(
                batch_size, self.max_encoder_length, 1, device=self.device
            )

        rnn_out, _ = self.rnn(input_data)
        # Summarize the sequence with the final time step's hidden state.
        last_hidden = rnn_out[:, -1, :]
        output = self.output_projector(last_hidden)

        if self.n_quantiles is not None:
            output = output.reshape(-1, self.max_prediction_length, self.n_quantiles)
        else:
            output = output.reshape(-1, self.max_prediction_length, 1)

        return {"prediction": output}
Loading