Skip to content

[ENH] Add v2 interface support for RecurrentNetwork (RNN)#2136

Open
Meet-Ramjiyani-10 wants to merge 16 commits intosktime:mainfrom
Meet-Ramjiyani-10:enh/rnn-v2-interface
Open

[ENH] Add v2 interface support for RecurrentNetwork (RNN)#2136
Meet-Ramjiyani-10 wants to merge 16 commits intosktime:mainfrom
Meet-Ramjiyani-10:enh/rnn-v2-interface

Conversation

@Meet-Ramjiyani-10
Copy link
Contributor

@Meet-Ramjiyani-10 Meet-Ramjiyani-10 commented Mar 4, 2026

Reference Issues/PRs

Closes #2128. Part of #1992.

What does this implement/fix? Explain your changes.

Adds v2 interface support for the RecurrentNetwork (RNN) model.

  • Added _rnn_v2.py implementing RecurrentNetwork_v2 using TslibBaseModel
  • Added _rnn_pkg_v2.py with tags, datamodule, and test parameters
  • Supports both LSTM and GRU cell types
  • Supports QuantileLoss for probabilistic forecasting

I followed the v2 pattern established by DLinear and SAMformer and have made no changes to existing v1 implementation.

What should a reviewer concentrate their feedback on?

  • forward pass input preparation
  • Whether the output shape is consistent with other v2 models
  • Tags in the pkg file

Did you add any tests for the change?

Test parameters added in get_test_train_params covering LSTM, GRU, single layer, and multi-layer configurations.

Any other comments?

Happy to iterate based on feedback.

PR checklist

  • [✓ ] The PR title starts with either [ENH], [MNT], [DOC], or [BUG]. [BUG] - bugfix, [MNT] - CI, test framework, [ENH] - adding or improving code, [DOC] - writing or improving documentation or docstrings.

  • [ ✓] Used pre-commit hooks when committing to ensure that code is compliant with hooks. Install hooks with pre-commit install.
    To run hooks independent of commit, execute pre-commit run --all-files

@codecov
Copy link

codecov bot commented Mar 4, 2026

Codecov Report

❌ Patch coverage is 98.66667% with 1 line in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (main@cf46d07). Learn more about missing BASE report.

Files with missing lines Patch % Lines
pytorch_forecasting/models/rnn/_rnn_v2.py 98.03% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2136   +/-   ##
=======================================
  Coverage        ?   86.72%           
=======================================
  Files           ?      167           
  Lines           ?     9806           
  Branches        ?        0           
=======================================
  Hits            ?     8504           
  Misses          ?     1302           
  Partials        ?        0           
Flag Coverage Δ
cpu 86.72% <98.66%> (?)
pytest 86.72% <98.66%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Member

@phoeenniixx phoeenniixx left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!
Added some comments

return TslibDataModule

@classmethod
def _get_test_datamodule_from(cls, trainer_kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think you need this method anymore? Please have a look at other pkgs (like timexer) for more info

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for pointing. After checking _timexer_pkg_v2.py. this method is not needed . Will remove it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@phoeenniixx done 👍

"""Get the underlying DataModule class."""
from pytorch_forecasting.data._tslib_data_module import TslibDataModule

return TslibDataModule
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it be EncoderDecoderDataModule?

Copy link
Contributor Author

@Meet-Ramjiyani-10 Meet-Ramjiyani-10 Mar 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@phoeenniixx After checking _timexer_pkg_v2., it also uses TslibDataModule in get_datamodule_cls. Should I still switch to EncoderDecoderDataModule, or is TslibDataModule correct here?

Copy link
Contributor

@PranavBhatP PranavBhatP Mar 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please switch to EncoderDecoderDataModule for models with v1 native implementations. The rule of thumb is to stick to TslibDataModule only for direct migrations of models from thuml(tslib).

Copy link
Contributor Author

@Meet-Ramjiyani-10 Meet-Ramjiyani-10 Mar 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@PranavBhatP
thanks for clarifying!. But it seems EncoderDecoderDataModule doesn't exist in the codebase yet. Should I keep TslibDataModule for now, or is there a different existing class I should use?

Copy link
Contributor

@PranavBhatP PranavBhatP Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EncoderDecoderTimeseriesDataModule is the exact name of the D2 data module I am referring, it is interchangeably referred to as EncoderDecoderDataModule (u get what i mean?).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes @PranavBhatP got it . I was searching for EncoderDecoderDataModule as a single class and couldn't find it , making the changes.. EncoderDecoderTimeSeriesDataModule

params : list of dict
Parameters to create testing instances of the class.
"""
from pytorch_forecasting.metrics import MAE, SMAPE, QuantileLoss
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean to use them in the params?
I think it would be good if we could chekc which type of losses this can support

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i will update get_test_train_params to include MAE, SMAPE, and QuantileLoss directly in the params to verify loss compatibility, as in _timexer_pkg_v2.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@phoeenniixx done 👍

@phoeenniixx phoeenniixx added module:models enhancement New feature or request labels Mar 4, 2026
@Meet-Ramjiyani-10
Copy link
Contributor Author

@phoeenniixx and @PranavBhatP
I’ve updated the RNN v2 package so that info:name is "RecurrentNetwork_v2", which fixes the first part of test_pkg_linkage. However, I’m still seeing failures from the second assertion:

assert object_pkg.name == object_class.name + "_pkg_v2", msg
For RNN v2 we currently have:

object_class.name == "RecurrentNetwork_v2"
object_pkg.name == "RecurrentNetwork_v2_pkg"
So the test expects "RecurrentNetwork_v2_pkg_v2", which doesn’t match the current naming. For the other v2 models (e.g. TFT, DLinear, Samformer), the class name stays the same and only the package class gets the _pkg_v2 suffix, so this rule fits them but not RNN v2.

Could you please advise what naming you’d prefer here? I see two options:

Rename the model class back to RecurrentNetwork and keep the package as RecurrentNetwork_pkg_v2 (matching the convention used for the other v2 models), or
Keep RecurrentNetwork_v2 / RecurrentNetwork_v2_pkg and relax/special‑case the test for this model.
Once I know which direction you’d like, I can adjust the code and tests accordingly.

@phoeenniixx
Copy link
Member

Yes this is a known issue, see #2080 by @PalakB09. You can:

The decision is depends on you
FYI @PranavBhatP @fkiraly

@Meet-Ramjiyani-10
Copy link
Contributor Author

okay, for now i will rename the model class to RNN

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request module:models

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[ENH] Add v2 interface support for RecurrentNetwork (RNN)

3 participants