[ENH] Add LightTS model implementation (v2) #2135
Open

Sylver-Icy wants to merge 14 commits into sktime:main (base: main) from Sylver-Icy:add-lightts

+648 −1
Changes shown from 12 of 14 commits.

Commits:
All commits are by Sylver-Icy:

616c180 [MNT] Add LightTS model implementation
b19c900 [MNT] Implement LightTS package container
0d87c2a [MNT] Add LightTS model initialization in the package
4452e76 [MNT] Add unit tests for LightTS model functionality
b0b9821 [MNT] Refactor LightTS model by removing dropout layer and adjusting …
ea8d34b Merge branch 'main' into add-lightts
2f90c4e [MNT] Apply ruff formatting fixes
a942cab [MNT] Enhance documentation for LightTS model and IEBlock class
571d426 [MNT] Remove unused method
9759c33 Merge branch 'main' into add-lightts
53d50e0 [MNT] Move _IEBlock to layers module
2fc7404 [MNT] Rename _IEBlock to IEBlock
2b37b89 [MNT] remove output squeeze
10c67d0 [MNT] update authors list in LightTS_pkg_v2
Changed file: pytorch_forecasting/layers/_blocks/__init__.py (path inferred from the import statements). The import is updated to use the renamed `IEBlock` class, matching the definition in `_ie_block.py` below:

```diff
@@ -1,3 +1,4 @@
+from pytorch_forecasting.layers._blocks._ie_block import IEBlock
 from pytorch_forecasting.layers._blocks._residual_block_dsipts import ResidualBlock
 
-__all__ = ["ResidualBlock"]
+__all__ = ["ResidualBlock", "IEBlock"]
```
New file: pytorch_forecasting/layers/_blocks/_ie_block.py (+74 lines):

```python
import torch
import torch.nn as nn


class IEBlock(nn.Module):
    """
    Information Exchange block used by LightTS.

    Applies spatial projection, channel mixing, and an output projection
    to exchange information across time-series channels.
    """

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_nodes: int
    ) -> None:
        """
        Initialize the IEBlock.

        Parameters
        ----------
        input_dim : int
            Input feature dimension.
        hidden_dim : int
            Hidden projection size.
        output_dim : int
            Output feature dimension.
        num_nodes : int
            Number of channels mixed by the block.
        """
        super().__init__()

        if input_dim <= 0:
            raise ValueError("input_dim must be positive.")
        if hidden_dim <= 0:
            raise ValueError("hidden_dim must be positive.")
        if output_dim <= 0:
            raise ValueError("output_dim must be positive.")
        if num_nodes <= 0:
            raise ValueError("num_nodes must be positive.")

        reduced_dim = max(1, hidden_dim // 4)

        self.spatial_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, reduced_dim),
        )

        self.channel_proj = nn.Linear(num_nodes, num_nodes)
        nn.init.eye_(self.channel_proj.weight)
        if self.channel_proj.bias is not None:
            nn.init.zeros_(self.channel_proj.bias)

        self.output_proj = nn.Linear(reduced_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IEBlock.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_size, input_dim, num_nodes)``.

        Returns
        -------
        torch.Tensor
            Output tensor of shape ``(batch_size, output_dim, num_nodes)``.
        """
        x = self.spatial_proj(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1) + self.channel_proj(x.permute(0, 2, 1))
        x = self.output_proj(x.permute(0, 2, 1))
        return x.permute(0, 2, 1)
```
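As a quick shape trace, the block maps `(batch, input_dim, num_nodes)` to `(batch, output_dim, num_nodes)`. The sketch below is a standalone re-statement of the same wiring (it does not import from the package, and `IEBlockSketch` is a hypothetical name used only here):

```python
import torch
import torch.nn as nn


class IEBlockSketch(nn.Module):
    # Standalone mirror of the IEBlock wiring above, for shape tracing only.
    def __init__(self, input_dim, hidden_dim, output_dim, num_nodes):
        super().__init__()
        reduced_dim = max(1, hidden_dim // 4)
        self.spatial_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, reduced_dim),
        )
        self.channel_proj = nn.Linear(num_nodes, num_nodes)
        self.output_proj = nn.Linear(reduced_dim, output_dim)

    def forward(self, x):
        # (B, input_dim, N) -> (B, N, reduced_dim)
        x = self.spatial_proj(x.permute(0, 2, 1))
        # residual channel mixing over the num_nodes axis: (B, reduced_dim, N)
        x = x.permute(0, 2, 1) + self.channel_proj(x.permute(0, 2, 1))
        # (B, N, reduced_dim) -> (B, N, output_dim) -> (B, output_dim, N)
        return self.output_proj(x.permute(0, 2, 1)).permute(0, 2, 1)


block = IEBlockSketch(input_dim=24, hidden_dim=32, output_dim=12, num_nodes=7)
out = block(torch.randn(8, 24, 7))
print(tuple(out.shape))  # (8, 12, 7)
```

Note the identity initialization of `channel_proj` in the real block: at the start of training the residual branch passes channels through unchanged, so the block initially behaves like a pure per-channel projection.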
New file: pytorch_forecasting/models/lightts/__init__.py (+8 lines):

```python
"""
LightTS model for time series forecasting.
"""

from pytorch_forecasting.models.lightts._lightts_pkg_v2 import LightTS_pkg_v2
from pytorch_forecasting.models.lightts._lightts_v2 import LightTS

__all__ = ["LightTS", "LightTS_pkg_v2"]
```
New file: pytorch_forecasting/models/lightts/_lightts_pkg_v2.py (+94 lines). The datamodule-config loop merges each parameter set's overrides into a fresh copy of the defaults, so the parameter dicts do not all share (and mutate) one config object:

```python
"""
Package container for LightTS model.
"""

from pytorch_forecasting.base._base_pkg import Base_pkg


class LightTS_pkg_v2(Base_pkg):
    """
    Package container describing the LightTS model.

    This class registers metadata, links the model implementation,
    and provides helper utilities used during testing.
    """

    _tags = {
        "info:name": "LightTS",
        "info:compute": 2,
        "authors": ["Sylver-Icy"],
        "capability:exogenous": True,
        "capability:multivariate": True,
        "capability:pred_int": True,
        "capability:flexible_history_length": True,
        "capability:cold_start": False,
    }

    @classmethod
    def get_cls(cls):
        """
        Return the LightTS model class.

        Returns
        -------
        type
            The LightTS model implementation.
        """
        from pytorch_forecasting.models.lightts._lightts_v2 import LightTS

        return LightTS

    @classmethod
    def get_datamodule_cls(cls):
        """
        Return the datamodule used for LightTS training and evaluation.

        Returns
        -------
        type
            The TslibDataModule class.
        """
        from pytorch_forecasting.data._tslib_data_module import TslibDataModule

        return TslibDataModule

    @classmethod
    def get_test_train_params(cls):
        """
        Provide parameter configurations used for automated model tests.

        Returns
        -------
        list of dict
            Different model parameter combinations used during testing.
        """
        from pytorch_forecasting.metrics import MAE, SMAPE, QuantileLoss

        params = [
            {},
            dict(
                d_model=128,
                chunk_size=4,
                logging_metrics=[SMAPE()],
                loss=MAE(),
            ),
            dict(
                d_model=64,
                chunk_size=2,
                dropout=0.2,
            ),
            dict(
                d_model=96,
                chunk_size=8,
                loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
            ),
        ]

        default_dm_cfg = {"context_length": 12, "prediction_length": 4}

        for param in params:
            current_dm_cfg = param.get("datamodule_cfg", {})
            # merge into a fresh copy so parameter sets do not share
            # and mutate a single config dict
            param["datamodule_cfg"] = {**default_dm_cfg, **current_dm_cfg}

        return params
```
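The merge step at the end of `get_test_train_params` can be exercised on its own. This sketch uses plain dicts and invented parameter values, independent of pytorch_forecasting, to show that defaults fill in missing keys while per-parameter overrides win:

```python
# Standalone sketch of the datamodule_cfg merge in get_test_train_params.
params = [
    {},
    {"d_model": 128, "datamodule_cfg": {"context_length": 24}},
]

default_dm_cfg = {"context_length": 12, "prediction_length": 4}

for param in params:
    # merge into a fresh copy so parameter sets do not share one dict
    param["datamodule_cfg"] = {**default_dm_cfg, **param.get("datamodule_cfg", {})}

print(params[0]["datamodule_cfg"])  # {'context_length': 12, 'prediction_length': 4}
print(params[1]["datamodule_cfg"])  # {'context_length': 24, 'prediction_length': 4}
```

Mutating the shared `default_dm_cfg` in place and assigning it directly, by contrast, would leak overrides from one parameter set into the next and leave every set pointing at the same dict.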