Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pytorch_forecasting/layers/_blocks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pytorch_forecasting.layers._blocks._ie_block import _IEBlock
from pytorch_forecasting.layers._blocks._residual_block_dsipts import ResidualBlock

__all__ = ["ResidualBlock"]
__all__ = ["ResidualBlock", "_IEBlock"]
74 changes: 74 additions & 0 deletions pytorch_forecasting/layers/_blocks/_ie_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn


class IEBlock(nn.Module):
"""
Information Exchange block used by LightTS.

Applies spatial projection, channel mixing, and an output projection
to exchange information across time-series channels.
"""

def __init__(
self, input_dim: int, hidden_dim: int, output_dim: int, num_nodes: int
) -> None:
"""
Initialize the IEBlock.

Parameters
----------
input_dim : int
Input feature dimension.
hidden_dim : int
Hidden projection size.
output_dim : int
Output feature dimension.
num_nodes : int
Number of channels mixed by the block.
"""
super().__init__()

if input_dim <= 0:
raise ValueError("input_dim must be positive.")
if hidden_dim <= 0:
raise ValueError("hidden_dim must be positive.")
if output_dim <= 0:
raise ValueError("output_dim must be positive.")
if num_nodes <= 0:
raise ValueError("num_nodes must be positive.")

reduced_dim = max(1, hidden_dim // 4)

self.spatial_proj = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.LeakyReLU(),
nn.Linear(hidden_dim, reduced_dim),
)

self.channel_proj = nn.Linear(num_nodes, num_nodes)
nn.init.eye_(self.channel_proj.weight)
if self.channel_proj.bias is not None:
nn.init.zeros_(self.channel_proj.bias)

self.output_proj = nn.Linear(reduced_dim, output_dim)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the IEBlock.

Parameters
----------
x : torch.Tensor
Input tensor of shape ``(batch_size, input_dim, num_nodes)``.

Returns
-------
torch.Tensor
Output tensor of shape ``(batch_size, output_dim, num_nodes)``.
"""

x = self.spatial_proj(x.permute(0, 2, 1))
x = x.permute(0, 2, 1) + self.channel_proj(x.permute(0, 2, 1))
x = self.output_proj(x.permute(0, 2, 1))
return x.permute(0, 2, 1)
8 changes: 8 additions & 0 deletions pytorch_forecasting/models/lightts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""
LightTS model for time series forecasting.
"""

from pytorch_forecasting.models.lightts._lightts_pkg_v2 import LightTS_pkg_v2
from pytorch_forecasting.models.lightts._lightts_v2 import LightTS

__all__ = ["LightTS", "LightTS_pkg_v2"]
94 changes: 94 additions & 0 deletions pytorch_forecasting/models/lightts/_lightts_pkg_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""
Package container for LightTS model.
"""

from pytorch_forecasting.base._base_pkg import Base_pkg


class LightTS_pkg_v2(Base_pkg):
"""
Package container describing the LightTS model.

This class registers metadata, links the model implementation,
and provides helper utilities used during testing.
"""

_tags = {
"info:name": "LightTS",
"info:compute": 2,
"authors": ["Sylver-Icy"],
"capability:exogenous": True,
"capability:multivariate": True,
"capability:pred_int": True,
"capability:flexible_history_length": True,
"capability:cold_start": False,
}

@classmethod
def get_cls(cls):
"""
Return the LightTS model class.

Returns
-------
type
The LightTS model implementation.
"""
from pytorch_forecasting.models.lightts._lightts_v2 import LightTS

return LightTS

@classmethod
def get_datamodule_cls(cls):
"""
Return the datamodule used for LightTS training and evaluation.

Returns
-------
type
The TslibDataModule class.
"""
from pytorch_forecasting.data._tslib_data_module import TslibDataModule

return TslibDataModule

@classmethod
def get_test_train_params(cls):
"""
Provide parameter configurations used for automated model tests.

Returns
-------
list of dict
Different model parameter combinations used during testing.
"""
from pytorch_forecasting.metrics import MAE, SMAPE, QuantileLoss

params = [
{},
dict(
d_model=128,
chunk_size=4,
logging_metrics=[SMAPE()],
loss=MAE(),
),
dict(
d_model=64,
chunk_size=2,
dropout=0.2,
),
dict(
d_model=96,
chunk_size=8,
loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
),
]

default_dm_cfg = {"context_length": 12, "prediction_length": 4}

for param in params:
current_dm_cfg = param.get("datamodule_cfg", {})
default_dm_cfg.update(current_dm_cfg)
param["datamodule_cfg"] = default_dm_cfg

return params
Loading
Loading