Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions pytorch_forecasting/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
from pytorch_forecasting.layers._blocks import ResidualBlock
from pytorch_forecasting.layers._decomposition import SeriesDecomposition
from pytorch_forecasting.layers._embeddings import (
DataEmbedding,
DataEmbedding_inverted,
EnEmbedding,
FixedEmbedding,
PositionalEmbedding,
TemporalEmbedding,
TokenEmbedding,
embedding_cat_variables,
)
from pytorch_forecasting.layers._encoders import (
Expand All @@ -33,6 +37,11 @@
sLSTMLayer,
sLSTMNetwork,
)
from pytorch_forecasting.layers._reformer import (
ReformerEncoder,
ReformerEncoderLayer,
ReformerLayer,
)

__all__ = [
"FullAttention",
Expand All @@ -54,4 +63,11 @@
"RevIN",
"ResidualBlock",
"embedding_cat_variables",
"ReformerEncoder",
"ReformerEncoderLayer",
"ReformerLayer",
"DataEmbedding",
"TemporalEmbedding",
"FixedEmbedding",
"TokenEmbedding",
]
10 changes: 10 additions & 0 deletions pytorch_forecasting/layers/_embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,27 @@
"""

from pytorch_forecasting.layers._embeddings._data_embedding import (
DataEmbedding,
DataEmbedding_inverted,
)
from pytorch_forecasting.layers._embeddings._en_embedding import EnEmbedding
from pytorch_forecasting.layers._embeddings._positional_embedding import (
PositionalEmbedding,
)
from pytorch_forecasting.layers._embeddings._sub_nn import embedding_cat_variables
from pytorch_forecasting.layers._embeddings._temporal_embedding import (
FixedEmbedding,
TemporalEmbedding,
)
from pytorch_forecasting.layers._embeddings._token_embedding import TokenEmbedding

__all__ = [
"PositionalEmbedding",
"DataEmbedding",
"DataEmbedding_inverted",
"EnEmbedding",
"embedding_cat_variables",
"FixedEmbedding",
"TemporalEmbedding",
"TokenEmbedding",
]
102 changes: 99 additions & 3 deletions pytorch_forecasting/layers/_embeddings/_data_embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
Data embedding layer for exogenous variables.
"""
"""Data embedding utilities."""

import math
from math import sqrt
Expand All @@ -10,6 +8,90 @@
import torch.nn as nn
import torch.nn.functional as F

from pytorch_forecasting.layers._embeddings._positional_embedding import (
PositionalEmbedding,
)
from pytorch_forecasting.layers._embeddings._temporal_embedding import TemporalEmbedding
from pytorch_forecasting.layers._embeddings._token_embedding import TokenEmbedding


class TimeFeatureEmbedding(nn.Module):
    """Linear projection of numeric time features into the model space.

    Args:
        d_model (int): output embedding dimension.
        embed_type (str): unused, retained only for API compatibility.
        freq (str): frequency code; fixes the expected number of input
            time features (e.g., 'h' -> 1, 't' -> 5).
    """

    def __init__(self, d_model, embed_type="timeF", freq="h"):
        super().__init__()

        # number of numeric time features expected per frequency code
        feature_counts = {
            "h": 1,
            "t": 5,
            "s": 6,
            "m": 1,
            "a": 1,
            "w": 2,
            "d": 3,
            "b": 3,
        }
        self.embed = nn.Linear(feature_counts[freq], d_model, bias=False)

    def forward(self, x):
        """Project the input time features linearly.

        Args:
            x (torch.Tensor): numeric time features whose last dimension
                matches the count implied by ``freq``.

        Returns:
            torch.Tensor: tensor whose last dimension is ``d_model``.
        """
        return self.embed(x)


class DataEmbedding(nn.Module):
    """Sum of token, positional and temporal embeddings with dropout.

    Args:
        c_in (int): number of input features/channels for the token
            embedding.
        d_model (int): model/embedding dimensionality.
        embed_type (str): temporal embedding type ('fixed', 'learned',
            or 'timeF').
        freq (str): frequency code forwarded to the temporal embedding.
        dropout (float): dropout probability applied to the summed
            embedding.
    """

    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)

        # 'timeF' selects the linear time-feature projection; anything else
        # uses the discrete calendar-component embedding.
        if embed_type == "timeF":
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        else:
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed values and, when available, temporal markers.

        Args:
            x (torch.Tensor): value input of shape ``[batch, seq_len, c_in]``.
            x_mark (torch.Tensor or None): temporal marker tensor; when
                ``None`` the temporal term is omitted.

        Returns:
            torch.Tensor: embedding of shape ``[batch, seq_len, d_model]``.
        """
        if x_mark is None:
            summed = self.value_embedding(x) + self.position_embedding(x)
        else:
            summed = (
                self.value_embedding(x)
                + self.temporal_embedding(x_mark)
                + self.position_embedding(x)
            )
        return self.dropout(summed)


class DataEmbedding_inverted(nn.Module):
"""
Expand All @@ -35,6 +117,20 @@ def __init__(self, c_in, d_model, dropout=0.1):
self.dropout = nn.Dropout(p=dropout)

def forward(self, x, x_mark):
"""Embed inputs where time is the last dimension.

This variant expects `x` shaped `[batch, seq_len, c_in]` and performs a
transpose before applying a linear projection. If `x_mark` is provided
it is concatenated along the feature dimension before projection.

Args:
x (torch.Tensor): input tensor of shape `[batch, seq_len, c_in]`.
x_mark (torch.Tensor or None): optional temporal features with the
same leading shape as `x`.

Returns:
torch.Tensor: embedded tensor with dropout applied.
"""
x = x.permute(0, 2, 1)
# x: [Batch Variate Time]
if x_mark is None:
Expand Down
98 changes: 98 additions & 0 deletions pytorch_forecasting/layers/_embeddings/_temporal_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Temporal embeddings for time features."""

import math

import torch
import torch.nn as nn


class FixedEmbedding(nn.Module):
    """Frozen sinusoidal embedding table for discrete indices.

    Builds the standard transformer sinusoidal table over ``c_in``
    positions and exposes it through a non-trainable ``nn.Embedding``.

    Args:
        c_in (int): number of discrete positions (e.g., months, hours).
        d_model (int): embedding dimensionality.
    """

    def __init__(self, c_in, d_model):
        super().__init__()

        w = torch.zeros(c_in, d_model).float()

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        ).exp()

        # even dimensions get sin, odd dimensions get cos
        w[:, 0::2] = torch.sin(position * div_term)
        w[:, 1::2] = torch.cos(position * div_term)

        # Fix: the original set ``w.require_grad = False`` — a typo for
        # ``requires_grad`` that merely attached an unused attribute. The
        # table is actually frozen via ``requires_grad=False`` below.
        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        """Return the fixed embeddings for the input indices.

        Args:
            x (torch.Tensor): integer tensor of indices (any shape).

        Returns:
            torch.Tensor: detached embedding vectors with the same leading
            shape as ``x`` and last dimension ``d_model``.
        """
        return self.emb(x).detach()


class TemporalEmbedding(nn.Module):
    """Summed embeddings of discrete calendar components.

    Args:
        d_model (int): embedding dimensionality.
        embed_type (str): 'fixed' selects :class:`FixedEmbedding`;
            any other value selects learned ``nn.Embedding`` tables.
        freq (str): frequency code; 't' (minutely) adds a minute table.
    """

    def __init__(self, d_model, embed_type="fixed", freq="h"):
        super().__init__()

        embed_cls = FixedEmbedding if embed_type == "fixed" else nn.Embedding

        # cardinality of each calendar component (minute buckets are
        # quarter-hours; day/month sizes leave room for 1-based indexing)
        minute_size, hour_size = 4, 24
        weekday_size, day_size, month_size = 7, 32, 13

        if freq == "t":
            self.minute_embed = embed_cls(minute_size, d_model)
        self.hour_embed = embed_cls(hour_size, d_model)
        self.weekday_embed = embed_cls(weekday_size, d_model)
        self.day_embed = embed_cls(day_size, d_model)
        self.month_embed = embed_cls(month_size, d_model)

    def forward(self, x):
        """Return the sum of the component temporal embeddings.

        Args:
            x (torch.Tensor): integer tensor of shape
                ``[batch, seq_len, num_time_features]`` whose last dimension
                is ordered ``[month, day, weekday, hour, minute]``
                (minute only consumed when ``freq == 't'``).

        Returns:
            torch.Tensor: summed embedding of shape
            ``[batch, seq_len, d_model]``.
        """
        x = x.long()
        out = self.hour_embed(x[:, :, 3])
        out = out + self.weekday_embed(x[:, :, 2])
        out = out + self.day_embed(x[:, :, 1])
        out = out + self.month_embed(x[:, :, 0])
        if hasattr(self, "minute_embed"):
            out = out + self.minute_embed(x[:, :, 4])
        return out
42 changes: 42 additions & 0 deletions pytorch_forecasting/layers/_embeddings/_token_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Token embedding module."""

import torch
import torch.nn as nn


class TokenEmbedding(nn.Module):
    """Token embedding via a circular 1D convolution.

    Args:
        c_in (int): number of input channels/features.
        d_model (int): output embedding dimension.
    """

    def __init__(self, c_in, d_model):
        super().__init__()
        # Fix: the original compared ``torch.__version__ >= "1.5.0"``
        # lexicographically, which misorders versions such as "1.10.0"
        # (< "1.5.0" as strings). Compare parsed (major, minor) instead.
        try:
            major, minor = (
                int(part)
                for part in torch.__version__.split("+")[0].split(".")[:2]
            )
            padding = 1 if (major, minor) >= (1, 5) else 2
        except ValueError:
            # unparsable dev/local version string: assume a modern torch
            padding = 1
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=padding,
            padding_mode="circular",
            bias=False,
        )
        # only one conv exists, so initialize it directly
        nn.init.kaiming_normal_(
            self.tokenConv.weight, mode="fan_in", nonlinearity="leaky_relu"
        )

    def forward(self, x):
        """Apply the convolutional token embedding.

        Args:
            x (torch.Tensor): input tensor of shape ``[batch, seq_len, c_in]``.

        Returns:
            torch.Tensor: embedded tensor of shape
            ``[batch, seq_len, d_model]``.
        """
        return self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
19 changes: 19 additions & 0 deletions pytorch_forecasting/layers/_reformer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Reformer-related layer exports.

This package exposes encoder building blocks and a thin wrapper around the
LSH self-attention implementation used by Reformer-style models.
"""

from pytorch_forecasting.layers._reformer._encoder import (
ReformerEncoder,
ReformerEncoderLayer,
)
from pytorch_forecasting.layers._reformer._lsh_self_attention import LSHSelfAttention
from pytorch_forecasting.layers._reformer._reformer_layer import ReformerLayer

__all__ = [
"ReformerEncoderLayer",
"ReformerEncoder",
"ReformerLayer",
"LSHSelfAttention",
]
Loading
Loading