-
Notifications
You must be signed in to change notification settings - Fork 820
[ENH] Add v2 interface support for RecurrentNetwork (RNN) #2136
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
da9b21a
d47c878
825b12d
8e26e9d
3498d10
e32bb6f
2c3d47e
57721fb
ddb218d
e15c963
d56dab1
fc53b3e
df73720
014f63d
da1025b
e0b007a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,139 @@ | ||
| """ | ||
| Package container for RecurrentNetwork v2 model. | ||
| """ | ||
|
|
||
| from pytorch_forecasting.base._base_pkg import Base_pkg | ||
|
|
||
|
|
||
class RecurrentNetwork_pkg_v2(Base_pkg):
    """Package container for the RecurrentNetwork v2 model.

    Bundles the model class, its DataModule class, and the data/parameter
    fixtures used by the package test framework.
    """

    # Capability/metadata tags consumed by the package registry.
    _tags = {
        "info:name": "RecurrentNetwork",
        "info:compute": 2,
        "authors": ["Meet-Ramjiyani-10"],
        "capability:exogenous": True,
        "capability:multivariate": True,
        "capability:pred_int": True,
        "capability:flexible_history_length": True,
        "capability:cold_start": False,
    }

    @classmethod
    def get_cls(cls):
        """Return the model class (imported lazily to avoid import cycles)."""
        from pytorch_forecasting.models.rnn._rnn_v2 import RecurrentNetwork_v2

        return RecurrentNetwork_v2

    @classmethod
    def get_datamodule_cls(cls):
        """Return the DataModule class used with this model."""
        from pytorch_forecasting.data._tslib_data_module import TslibDataModule

        return TslibDataModule

    @classmethod
    def _get_test_datamodule_from(cls, trainer_kwargs):
        """Create test dataloaders from ``trainer_kwargs``.

        Parameters
        ----------
        trainer_kwargs : dict
            May contain a ``data_loader_kwargs`` dict whose entries override
            dataset construction defaults and datamodule settings
            (``context_length``, ``prediction_length``, ``batch_size``,
            ``add_relative_time_idx``).

        Returns
        -------
        dict
            Keys ``"train"``, ``"val"``, ``"test"`` mapping to dataloaders,
            and ``"data_module"`` mapping to the training datamodule.
        """
        from pytorch_forecasting.data._tslib_data_module import TslibDataModule
        from pytorch_forecasting.tests._data_scenarios import (
            data_with_covariates_v2,
            make_datasets_v2,
        )

        data_with_covariates = data_with_covariates_v2()
        dataset_kwargs = dict(
            target="target",
            group_ids=["agency_encoded", "sku_encoded"],
            add_relative_time_idx=True,
        )
        data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {})
        dataset_kwargs.update(data_loader_kwargs)

        datasets_info = make_datasets_v2(data_with_covariates, **dataset_kwargs)
        training_dataset = datasets_info["training_dataset"]
        validation_dataset = datasets_info["validation_dataset"]

        # The three datamodules differ only in dataset and split; build them
        # through one helper so the shared configuration lives in one place.
        def _make_datamodule(dataset, split):
            return TslibDataModule(
                time_series_dataset=dataset,
                context_length=data_loader_kwargs.get("context_length", 8),
                prediction_length=data_loader_kwargs.get("prediction_length", 2),
                add_relative_time_idx=data_loader_kwargs.get(
                    "add_relative_time_idx", True
                ),
                batch_size=data_loader_kwargs.get("batch_size", 2),
                train_val_test_split=split,
            )

        train_datamodule = _make_datamodule(training_dataset, (0.8, 0.2, 0.0))
        val_datamodule = _make_datamodule(validation_dataset, (0.0, 1.0, 0.0))
        test_datamodule = _make_datamodule(validation_dataset, (0.0, 0.0, 1.0))

        train_datamodule.setup("fit")
        val_datamodule.setup("fit")
        test_datamodule.setup("test")

        return {
            "train": train_datamodule.train_dataloader(),
            "val": val_datamodule.val_dataloader(),
            "test": test_datamodule.test_dataloader(),
            "data_module": train_datamodule,
        }

    @classmethod
    def get_test_train_params(cls):
        """
        Return testing parameter settings for the trainer.

        Returns
        -------
        params : list of dict
            Parameters to create testing instances of the class.  Each dict
            carries its own independent ``datamodule_cfg``.
        """
        params = [
            {},
            dict(cell_type="LSTM", hidden_size=32, rnn_layers=1),
            dict(cell_type="GRU", hidden_size=32, rnn_layers=1),
            dict(
                cell_type="LSTM",
                hidden_size=16,
                rnn_layers=2,
                dropout=0.1,
            ),
        ]

        default_dm_cfg = {"context_length": 8, "prediction_length": 2}

        for param in params:
            # Merge per-param overrides into a *fresh copy* of the defaults.
            # The original code mutated a single shared dict and assigned that
            # same object to every param, so one param's overrides leaked into
            # all subsequent params and every param aliased one mutable dict.
            param["datamodule_cfg"] = {
                **default_dm_cfg,
                **param.get("datamodule_cfg", {}),
            }

        return params
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,170 @@ | ||
| """ | ||
| Recurrent Network (LSTM/GRU) model for PyTorch Forecasting v2. | ||
| --------------------------------------------------------------- | ||
| """ | ||
|
|
||
| from typing import Any | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| from torch.optim import Optimizer | ||
|
|
||
| from pytorch_forecasting.metrics import QuantileLoss | ||
| from pytorch_forecasting.models.base._tslib_base_model_v2 import TslibBaseModel | ||
|
|
||
|
|
||
class RecurrentNetwork_v2(TslibBaseModel):
    """
    Recurrent Network model for time series forecasting.

    Supports LSTM and GRU cell types. Encodes the input sequence
    using a recurrent layer and projects the final hidden state
    to the prediction horizon.

    Parameters
    ----------
    loss : nn.Module
        Loss function for training.
    cell_type : str, optional
        Recurrent cell type, either "LSTM" or "GRU". Default is "LSTM".
    hidden_size : int, optional
        Number of features in the hidden state. Default is 64.
    rnn_layers : int, optional
        Number of recurrent layers. Default is 2.
    dropout : float, optional
        Dropout rate between RNN layers. Default is 0.1.
    logging_metrics : list[nn.Module], optional
        Metrics to log during training. Default is None.
    optimizer : str or Optimizer, optional
        Optimizer to use. Default is "adam".
    optimizer_params : dict, optional
        Parameters for the optimizer. Default is None.
    lr_scheduler : str, optional
        Learning rate scheduler. Default is None.
    lr_scheduler_params : dict, optional
        Parameters for the scheduler. Default is None.
    metadata : dict, optional
        Metadata from TslibDataModule. Default is None.

    Raises
    ------
    ValueError
        If ``cell_type`` is neither ``"LSTM"`` nor ``"GRU"``.
    """

    @classmethod
    def _pkg(cls):
        """Package containing the model."""
        from pytorch_forecasting.models.rnn._rnn_pkg_v2 import RecurrentNetwork_pkg_v2

        return RecurrentNetwork_pkg_v2

    def __init__(
        self,
        loss: nn.Module,
        cell_type: str = "LSTM",
        hidden_size: int = 64,
        rnn_layers: int = 2,
        dropout: float = 0.1,
        logging_metrics: list[nn.Module] | None = None,
        optimizer: Optimizer | str | None = "adam",
        optimizer_params: dict | None = None,
        lr_scheduler: str | None = None,
        lr_scheduler_params: dict | None = None,
        metadata: dict | None = None,
        **kwargs: Any,
    ):
        super().__init__(
            loss=loss,
            logging_metrics=logging_metrics,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            lr_scheduler=lr_scheduler,
            lr_scheduler_params=lr_scheduler_params,
            metadata=metadata,
        )

        # Validate with a real exception: ``assert`` is stripped under
        # ``python -O`` and would silently let an invalid cell_type through.
        if cell_type not in ("LSTM", "GRU"):
            raise ValueError(
                f"cell_type must be 'LSTM' or 'GRU', got '{cell_type}'"
            )

        self.cell_type = cell_type
        self.hidden_size = hidden_size
        self.rnn_layers = rnn_layers
        self.dropout = dropout

        self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"])

        # For quantile losses the output head predicts one value per quantile.
        self.n_quantiles = None
        if isinstance(loss, QuantileLoss):
            self.n_quantiles = len(loss.quantiles)

        self._init_network()

    def _init_network(self):
        """Initialize the RNN encoder and the output projection layer."""
        # Inputs are the continuous covariates concatenated with the target
        # history (see ``forward``).  ``cont_dim``/``target_dim`` are
        # presumably set by the base class from the datamodule metadata —
        # TODO(review): confirm against TslibBaseModel.
        input_size = self.cont_dim + self.target_dim

        # LSTM and GRU take identical constructor arguments; pick the class
        # instead of duplicating the whole call.
        rnn_cls = nn.LSTM if self.cell_type == "LSTM" else nn.GRU
        self.rnn = rnn_cls(
            input_size=input_size,
            hidden_size=self.hidden_size,
            num_layers=self.rnn_layers,
            # torch warns if dropout is nonzero on a single-layer RNN.
            dropout=self.dropout if self.rnn_layers > 1 else 0,
            batch_first=True,
        )

        if self.n_quantiles is not None:
            # NOTE(review): the quantile head ignores target_dim, i.e. it
            # assumes a univariate target — confirm for multivariate use.
            output_size = self.prediction_length * self.n_quantiles
        else:
            output_size = self.prediction_length * self.target_dim

        self.output_projector = nn.Linear(self.hidden_size, output_size)

    def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Forward pass of the RecurrentNetwork model.

        Parameters
        ----------
        x : dict[str, torch.Tensor]
            Dictionary containing input tensors.  Uses ``"history_cont"``
            and/or ``"history_target"`` (at least one must be present and
            non-empty), and optionally ``"target_scale"`` for rescaling.

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary containing output tensors with key "prediction".

        Raises
        ------
        ValueError
            If neither ``"history_cont"`` nor ``"history_target"`` provides
            any features.
        """
        available_features = []

        # Only include feature tensors that actually carry channels; an
        # empty last dimension would make the concatenation meaningless.
        if "history_cont" in x and x["history_cont"].size(-1) > 0:
            available_features.append(x["history_cont"])

        if "history_target" in x and x["history_target"].size(-1) > 0:
            available_features.append(x["history_target"])

        if not available_features:
            raise ValueError("No valid input features found in input dictionary.")

        input_data = torch.cat(available_features, dim=-1)

        # Encode the history and project only the final time step's hidden
        # state to the full prediction horizon.
        rnn_out, _ = self.rnn(input_data)
        last_hidden = rnn_out[:, -1, :]
        output = self.output_projector(last_hidden)

        if self.n_quantiles is not None:
            output = output.reshape(-1, self.prediction_length, self.n_quantiles)
        else:
            output = output.reshape(-1, self.prediction_length, self.target_dim)

        # Rescale predictions back to the original target scale when the
        # base class provides a transform — TODO(review): confirm
        # ``transform_output`` signature against TslibBaseModel.
        if "target_scale" in x and hasattr(self, "transform_output"):
            output = self.transform_output(output, x["target_scale"])

        return {"prediction": output}
Uh oh!
There was an error while loading. Please reload this page.