diff --git a/pytorch_forecasting/layers/__init__.py b/pytorch_forecasting/layers/__init__.py index ffaaec653..2468225d3 100644 --- a/pytorch_forecasting/layers/__init__.py +++ b/pytorch_forecasting/layers/__init__.py @@ -10,9 +10,13 @@ from pytorch_forecasting.layers._blocks import ResidualBlock from pytorch_forecasting.layers._decomposition import SeriesDecomposition from pytorch_forecasting.layers._embeddings import ( + DataEmbedding, DataEmbedding_inverted, EnEmbedding, + FixedEmbedding, PositionalEmbedding, + TemporalEmbedding, + TokenEmbedding, embedding_cat_variables, ) from pytorch_forecasting.layers._encoders import ( @@ -33,6 +37,11 @@ sLSTMLayer, sLSTMNetwork, ) +from pytorch_forecasting.layers._reformer import ( + ReformerEncoder, + ReformerEncoderLayer, + ReformerLayer, +) __all__ = [ "FullAttention", @@ -54,4 +63,11 @@ "RevIN", "ResidualBlock", "embedding_cat_variables", + "ReformerEncoder", + "ReformerEncoderLayer", + "ReformerLayer", + "DataEmbedding", + "TemporalEmbedding", + "FixedEmbedding", + "TokenEmbedding", ] diff --git a/pytorch_forecasting/layers/_embeddings/__init__.py b/pytorch_forecasting/layers/_embeddings/__init__.py index 7e1977bc9..94e7c44d9 100644 --- a/pytorch_forecasting/layers/_embeddings/__init__.py +++ b/pytorch_forecasting/layers/_embeddings/__init__.py @@ -3,6 +3,7 @@ """ from pytorch_forecasting.layers._embeddings._data_embedding import ( + DataEmbedding, DataEmbedding_inverted, ) from pytorch_forecasting.layers._embeddings._en_embedding import EnEmbedding @@ -10,10 +11,19 @@ PositionalEmbedding, ) from pytorch_forecasting.layers._embeddings._sub_nn import embedding_cat_variables +from pytorch_forecasting.layers._embeddings._temporal_embedding import ( + FixedEmbedding, + TemporalEmbedding, +) +from pytorch_forecasting.layers._embeddings._token_embedding import TokenEmbedding __all__ = [ "PositionalEmbedding", + "DataEmbedding", "DataEmbedding_inverted", "EnEmbedding", "embedding_cat_variables", + "FixedEmbedding", + 
class TimeFeatureEmbedding(nn.Module):
    """Project numeric time features into the model dimension.

    Args:
        d_model (int): output embedding dimension.
        embed_type (str): unused; retained for API compatibility.
        freq (str): frequency code fixing the expected number of input
            time features (e.g. 'h' -> 1, 't' -> 5).
    """

    def __init__(self, d_model, embed_type="timeF", freq="h"):
        super().__init__()
        # Number of numeric time features per frequency code; unknown codes
        # raise KeyError, matching the original lookup behavior.
        n_features = {"h": 1, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3}[
            freq
        ]
        self.embed = nn.Linear(n_features, d_model, bias=False)

    def forward(self, x):
        """Linearly project time features.

        Args:
            x (torch.Tensor): numeric time features whose last dimension
                matches the feature count implied by ``freq``.

        Returns:
            torch.Tensor: projection with last dimension ``d_model``.
        """
        return self.embed(x)


class DataEmbedding(nn.Module):
    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
        """Compose token, positional and temporal embeddings.

        Args:
            c_in (int): number of input features/channels for the token
                embedding.
            d_model (int): model/embedding dimensionality.
            embed_type (str): type of temporal embedding ('fixed', 'learned',
                or 'timeF').
            freq (str): frequency code passed to the temporal embedding.
            dropout (float): dropout probability applied to the summed
                embedding.
        """
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)

        # 'timeF' means numeric time features -> linear projection; anything
        # else uses discrete calendar-component embeddings.
        if embed_type == "timeF":
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        else:
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Compute the full data embedding used by sequence models.

        Args:
            x (torch.Tensor): token/value input `[batch, seq_len, c_in]`.
            x_mark (torch.Tensor or None): temporal markers used for the
                temporal embedding (shape `[batch, seq_len, ...]`).

        Returns:
            torch.Tensor: embedded tensor `[batch, seq_len, d_model]`.
        """
        if x_mark is None:
            summed = self.value_embedding(x) + self.position_embedding(x)
        else:
            summed = (
                self.value_embedding(x)
                + self.temporal_embedding(x_mark)
                + self.position_embedding(x)
            )
        return self.dropout(summed)
class FixedEmbedding(nn.Module):
    """Non-trainable sinusoidal embedding over a fixed set of indices.

    Args:
        c_in (int): number of discrete positions (e.g., months, hours).
        d_model (int): embedding dimensionality. NOTE(review): the sin/cos
            column fill assumes an even ``d_model`` — odd values raise a
            shape mismatch; confirm callers only pass even sizes.
    """

    def __init__(self, c_in, d_model):
        super().__init__()

        # Build the sinusoidal table once. The table is frozen below via
        # nn.Parameter(..., requires_grad=False); the upstream
        # `w.require_grad = False` was a typo (missing the trailing 's')
        # and had no effect, so it is dropped here.
        w = torch.zeros(c_in, d_model).float()

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        ).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        w[:, 1::2] = torch.cos(position * div_term)

        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        """Look up frozen sinusoidal embeddings for integer indices.

        Args:
            x (torch.Tensor): integer tensor of indices (any shape).

        Returns:
            torch.Tensor: detached embedding vectors with the same leading
            shape as `x` and last dimension `d_model`.
        """
        return self.emb(x).detach()
class TemporalEmbedding(nn.Module):
    """Sum per-component embeddings of calendar features.

    Args:
        d_model (int): embedding dimensionality.
        embed_type (str): 'fixed' selects :class:`FixedEmbedding`; anything
            else uses a learned ``nn.Embedding``.
        freq (str): frequency code; 't' (minute frequency) adds a minute
            embedding.
    """

    def __init__(self, d_model, embed_type="fixed", freq="h"):
        super().__init__()

        embed_cls = FixedEmbedding if embed_type == "fixed" else nn.Embedding

        # Vocabulary sizes per calendar component (day uses 32 so 1-based
        # day-of-month indices fit; month uses 13 for the same reason).
        if freq == "t":
            self.minute_embed = embed_cls(4, d_model)
        self.hour_embed = embed_cls(24, d_model)
        self.weekday_embed = embed_cls(7, d_model)
        self.day_embed = embed_cls(32, d_model)
        self.month_embed = embed_cls(13, d_model)

    def forward(self, x):
        """Return the summed temporal embedding.

        Args:
            x (torch.Tensor): integer tensor `[batch, seq_len, n_features]`
                whose columns are `[month, day, weekday, hour, minute]`
                (minute column only read when ``freq='t'``).

        Returns:
            torch.Tensor: summed embedding `[batch, seq_len, d_model]`.
        """
        idx = x.long()
        total = self.hour_embed(idx[:, :, 3]) + self.weekday_embed(idx[:, :, 2])
        total = total + self.day_embed(idx[:, :, 1]) + self.month_embed(idx[:, :, 0])
        # The minute embedding only exists at minute frequency.
        if hasattr(self, "minute_embed"):
            total = total + self.minute_embed(idx[:, :, 4])
        return total
class TokenEmbedding(nn.Module):
    """Token embedding via a circular 1D convolution over the time axis.

    Args:
        c_in (int): number of input channels/features.
        d_model (int): output embedding dimension.
    """

    def __init__(self, c_in, d_model):
        super().__init__()
        # Kernel size 3 with circular padding 1 preserves sequence length on
        # torch >= 1.5; older torch needed padding 2. The upstream check
        # compared version *strings* lexicographically
        # (`torch.__version__ >= "1.5.0"`), which misclassifies e.g.
        # "1.10.0" as older than "1.5.0"; compare numerically instead.
        padding = 1 if self._torch_at_least(1, 5) else 2
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=padding,
            padding_mode="circular",
            bias=False,
        )
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_in", nonlinearity="leaky_relu"
                )

    @staticmethod
    def _torch_at_least(major, minor):
        """Return True when the running torch version is >= major.minor."""
        parts = torch.__version__.split("+")[0].split(".")
        try:
            return (int(parts[0]), int(parts[1])) >= (major, minor)
        except (ValueError, IndexError):
            # Unparseable version (e.g. dev build): assume a modern torch.
            return True

    def forward(self, x):
        """Apply the convolutional token embedding.

        Args:
            x (torch.Tensor): input tensor of shape `[batch, seq_len, c_in]`.

        Returns:
            torch.Tensor: embedded tensor `[batch, seq_len, d_model]`.
        """
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x
class ReformerEncoderLayer(nn.Module):
    """Attention plus a pointwise-conv feed-forward block with residuals.

    Args:
        attention (callable): module called as
            ``attention(q, k, v, attn_mask=None, tau=None, delta=None)``
            and returning ``(output, attn_weights)``.
        d_model (int): model hidden dimensionality.
        d_ff (int or None): feed-forward width; falsy values default to
            ``4 * d_model``.
        dropout (float): dropout probability.
        activation (str): "relu" selects ReLU, anything else GELU.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Run self-attention and the feed-forward sublayer with residuals.

        Args:
            x (torch.Tensor): input `[batch, seq_len, d_model]`.
            attn_mask (torch.Tensor or None): optional attention mask.
            tau, delta: optional hyperparameters forwarded to the attention
                module.

        Returns:
            tuple: `(output, attn)` with `output` of shape
            `[batch, seq_len, d_model]` and `attn` whatever diagnostics the
            attention module returned.
        """
        attn_out, attn = self.attention(
            x, x, x, attn_mask=attn_mask, tau=tau, delta=delta
        )
        # First residual connection + norm.
        x = self.norm1(x + self.dropout(attn_out))

        # Position-wise feed-forward implemented as two 1x1 convolutions
        # over the channel dimension.
        ff = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        ff = self.dropout(self.conv2(ff).transpose(-1, 1))

        return self.norm2(x + ff), attn
class ReformerEncoder(nn.Module):
    """Stack of `ReformerEncoderLayer`s with optional interleaved conv layers."""

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = None if conv_layers is None else nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Run the encoder stack.

        Args:
            x (torch.Tensor): input tensor `[batch, seq_len, d_model]`.
            attn_mask (torch.Tensor or None): optional attention mask.
            tau, delta: optional hyperparameters forwarded to the attention
                layers.

        Returns:
            tuple: `(x, attns)` — the encoded output and the per-layer
            attention diagnostics.
        """
        attns = []
        if self.conv_layers is None:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                attns.append(attn)
        else:
            # Interleave attention and conv layers. `delta` is only applied
            # to the first attention layer; the final attention layer runs
            # after the last conv layer and without a mask or delta.
            for i, (attn_layer, conv_layer) in enumerate(
                zip(self.attn_layers, self.conv_layers)
            ):
                x, attn = attn_layer(
                    x,
                    attn_mask=attn_mask,
                    tau=tau,
                    delta=delta if i == 0 else None,
                )
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
            attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns
def _look_one_back(x: torch.Tensor) -> torch.Tensor:
    """Concatenate each bucket with its left neighbour.

    The first bucket wraps around to the last one, so every chunk doubles
    its attended key set without changing chunk count.
    """
    x_extra = torch.cat([x[:, -1:], x[:, :-1]], dim=1)
    return torch.cat([x_extra, x], dim=2)


class LSHAttention(nn.Module):
    """Low-level LSH attention operating on already-projected QK and V.

    Args:
        bucket_size (int): tokens per bucket half (full bucket = 2*bucket_size).
        n_hashes (int): number of hash rounds; outputs are averaged.
        causal (bool): mask future tokens.
        allow_duplicate_attention (bool): let a token attend to itself twice
            (once as Q, once as K).
        attend_across_buckets (bool): each bucket also attends to the
            preceding bucket (doubles the attended set).
        drop_for_hash_rate (float): dropout on hash vectors.
        dropout (float): dropout on attention weights.
    """

    def __init__(
        self,
        bucket_size: int = 64,
        n_hashes: int = 8,
        causal: bool = False,
        allow_duplicate_attention: bool = True,
        attend_across_buckets: bool = True,
        drop_for_hash_rate: float = 0.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.bucket_size = bucket_size
        self.n_hashes = n_hashes
        self.causal = causal
        self.allow_duplicate_attention = allow_duplicate_attention
        self.attend_across_buckets = attend_across_buckets
        self.hash_dropout = nn.Dropout(p=drop_for_hash_rate)
        self.attn_dropout = nn.Dropout(p=dropout)

    def _hash_vectors(
        self, vecs: torch.Tensor, n_buckets: int, random_rotations: torch.Tensor
    ) -> torch.Tensor:
        """Assign each vector to one of `n_buckets` buckets via random projections.

        Args:
            vecs: (batch, seq_len, dim_head)
            n_buckets: total number of buckets
            random_rotations: (dim_head, n_hashes, n_buckets // 2)

        Returns:
            buckets: (batch, n_hashes * seq_len) integer bucket indices
        """
        batch, _, _ = vecs.shape
        vecs = self.hash_dropout(vecs)

        rotated = torch.einsum("bld,dhr->blhr", vecs, random_rotations)

        # Using [r, -r] doubles the projection count to n_buckets while only
        # sampling n_buckets // 2 rotations (standard angular-LSH trick).
        rotated = torch.cat([rotated, -rotated], dim=-1)
        buckets = rotated.argmax(dim=-1)

        # Flatten hash rounds into the sequence axis: (batch, n_hashes * seq_len).
        buckets = buckets.permute(0, 2, 1).reshape(batch, -1)
        return buckets

    def forward(
        self,
        qk: torch.Tensor,
        v: torch.Tensor,
        input_mask: torch.Tensor | None = None,
    ):
        """
        Args:
            qk: (batch, seq_len, dim_head) — shared Q=K projection
            v: (batch, seq_len, dim_head)
            input_mask: (batch, seq_len) bool, True = keep

        Returns:
            out: (batch, seq_len, dim_head)
        """
        batch, seq_len, dim = qk.shape
        bucket_size = self.bucket_size
        n_hashes = self.n_hashes

        assert seq_len % (bucket_size * 2) == 0, (
            f"Sequence length {seq_len} must be divisible by \
            bucket_size*2={bucket_size*2}. "
            "Use ReformerLayer.fit_length() to pad first."
        )

        n_buckets = seq_len // bucket_size

        # Fresh random rotations every call: hashing is re-sampled per
        # forward pass, so outputs are stochastic unless the RNG is seeded.
        random_rotations = torch.randn(
            dim, n_hashes, n_buckets // 2, device=qk.device, dtype=qk.dtype
        )
        qk_norm = F.normalize(qk, p=2, dim=-1)
        buckets = self._hash_vectors(qk_norm, n_buckets, random_rotations)

        # Composite sort key: primary by bucket id, secondary by original
        # position, so same-bucket tokens end up adjacent and ordered.
        ticker = torch.arange(n_hashes * seq_len, device=qk.device).unsqueeze(0)
        buckets_and_t = seq_len * buckets + (ticker % seq_len)

        _, sorted_idx = buckets_and_t.sort(dim=-1)
        _, unsorted_idx = sorted_idx.sort(dim=-1)

        # Replicate the sequence once per hash round so all rounds can be
        # sorted and chunked together.
        qk_tiled = (
            qk_norm.unsqueeze(1).expand(-1, n_hashes, -1, -1).reshape(batch, -1, dim)
        )
        v_tiled = v.unsqueeze(1).expand(-1, n_hashes, -1, -1).reshape(batch, -1, dim)

        exp_idx = sorted_idx.unsqueeze(-1).expand(-1, -1, dim)
        qk_sorted = qk_tiled.gather(1, exp_idx)
        v_sorted = v_tiled.gather(1, exp_idx)

        chunk_size = bucket_size * 2
        total_len = n_hashes * seq_len

        n_chunks = total_len // chunk_size
        qk_chunks = qk_sorted.reshape(batch, n_chunks, chunk_size, dim)
        v_chunks = v_sorted.reshape(batch, n_chunks, chunk_size, dim)

        if self.attend_across_buckets:
            qk_attend = _look_one_back(qk_chunks)
            v_attend = _look_one_back(v_chunks)
        else:
            qk_attend = qk_chunks
            v_attend = v_chunks

        # Scaled dot-product within each chunk (queries = keys here).
        scale = dim**-0.5
        dots = torch.einsum("bcqd,bckd->bcqk", qk_chunks * scale, qk_attend)
        sorted_idx_chunks = sorted_idx.reshape(batch, n_chunks, chunk_size)
        # Recover each sorted slot's original sequence position (mod seq_len
        # strips the hash-round offset baked into the sort key).
        true_pos_q = sorted_idx_chunks % seq_len

        if self.attend_across_buckets:
            true_pos_k_extra = torch.cat(
                [true_pos_q[:, -1:], true_pos_q[:, :-1]], dim=1
            )
            true_pos_k = torch.cat([true_pos_k_extra, true_pos_q], dim=2)
        else:
            true_pos_k = true_pos_q

        if not self.allow_duplicate_attention:
            dupe_mask = true_pos_q.unsqueeze(-1) == true_pos_k.unsqueeze(-2)
            dots.masked_fill_(dupe_mask, float("-inf"))

        if self.causal:
            # Mask keys whose original position is strictly in the future.
            causal_mask = true_pos_q.unsqueeze(-1) < true_pos_k.unsqueeze(-2)
            dots.masked_fill_(causal_mask, float("-inf"))

        if input_mask is not None:
            mask_tiled = (
                input_mask.unsqueeze(1).expand(-1, n_hashes, -1).reshape(batch, -1)
            )
            mask_sorted_idx = sorted_idx.reshape(batch, n_chunks, chunk_size)
            if self.attend_across_buckets:
                k_extra_idx = torch.cat(
                    [mask_sorted_idx[:, -1:], mask_sorted_idx[:, :-1]], dim=1
                )
                k_idx = torch.cat([k_extra_idx, mask_sorted_idx], dim=2)
            else:
                k_idx = mask_sorted_idx
            # NOTE(review): gathering from mask_tiled with positions mod
            # seq_len reads the first hash round's copy of the mask; all
            # copies are identical, so this is equivalent to indexing
            # input_mask directly — confirm if the tiling is ever made
            # round-dependent.
            orig_k_pos = k_idx % seq_len
            k_mask = mask_tiled.gather(1, orig_k_pos.reshape(batch, -1)).reshape(
                batch, n_chunks, -1
            )
            dots.masked_fill_(~k_mask.unsqueeze(2).bool(), float("-inf"))

        attn = F.softmax(dots, dim=-1)
        attn = self.attn_dropout(attn)

        out_chunks = torch.einsum("bcqk,bckd->bcqd", attn, v_attend)

        out_sorted = out_chunks.reshape(batch, total_len, dim)

        # Undo the bucket sort, then average over hash rounds.
        exp_unsort_idx = unsorted_idx.unsqueeze(-1).expand(-1, -1, dim)
        out_unsorted = out_sorted.gather(1, exp_unsort_idx)

        out = out_unsorted.reshape(batch, n_hashes, seq_len, dim).mean(dim=1)

        return out
class LSHSelfAttention(nn.Module):
    """Multi-head LSH Self-Attention layer.

    Args:
        dim (int): model dimensionality (input and output).
        heads (int): number of attention heads.
        bucket_size (int): tokens per half-bucket; sequence length must be
            divisible by ``bucket_size * 2``.
        n_hashes (int): number of LSH hash rounds to average over.
        causal (bool): enable causal (autoregressive) masking.
        dim_head (int | None): dimensionality per head; defaults to
            ``dim // heads``.
        attn_chunks (int): process attention in this many sequential chunks
            to trade speed for memory (1 = no chunking).
        dropout (float): dropout on attention weights.
        post_attn_dropout (float): dropout after the output projection.
        allow_duplicate_attention (bool): passed to ``LSHAttention``.
        attend_across_buckets (bool): passed to ``LSHAttention``.
        num_mem_kv (int): number of persistent memory key-value pairs appended
            to every sequence (similar to "all-attention" paper).
    """

    def __init__(
        self,
        dim: int,
        heads: int = 8,
        bucket_size: int = 64,
        n_hashes: int = 8,
        causal: bool = False,
        dim_head: int | None = None,
        attn_chunks: int = 1,
        dropout: float = 0.0,
        post_attn_dropout: float = 0.0,
        allow_duplicate_attention: bool = True,
        attend_across_buckets: bool = True,
        num_mem_kv: int = 0,
    ):
        super().__init__()
        dim_head = dim_head or (dim // heads)
        inner_dim = dim_head * heads

        self.heads = heads
        self.dim_head = dim_head
        self.bucket_size = bucket_size
        self.attn_chunks = attn_chunks
        self.num_mem_kv = num_mem_kv

        # Reformer shares the Q and K projection (a single to_qk); V has its
        # own projection.
        self.to_qk = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias=False),
            nn.Dropout(post_attn_dropout),
        )

        self.lsh_attn = LSHAttention(
            bucket_size=bucket_size,
            n_hashes=n_hashes,
            causal=causal,
            allow_duplicate_attention=allow_duplicate_attention,
            attend_across_buckets=attend_across_buckets,
            dropout=dropout,
        )

        # Learned persistent key/value pairs, stored concatenated along the
        # last dimension and split in forward().
        if num_mem_kv > 0:
            self.mem_kv = nn.Parameter(torch.randn(1, num_mem_kv, dim_head * 2))
        else:
            self.mem_kv = None

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        input_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, dim)
            input_mask: (batch, seq_len) bool tensor, True = valid token.

        Returns:
            (batch, seq_len, dim)
        """
        batch, seq_len, _ = x.shape
        heads = self.heads
        dim_h = self.dim_head

        qk = self.to_qk(x)
        v = self.to_v(x)

        # (batch, seq, heads*dim_h) -> (batch, heads, seq, dim_h)
        def split_heads(t):
            return t.reshape(batch, seq_len, heads, dim_h).permute(0, 2, 1, 3)

        qk = split_heads(qk)
        v = split_heads(v)

        if self.mem_kv is not None:
            # Append the persistent memory positions to every sequence and
            # mark them as always-valid in the mask.
            mem = self.mem_kv.expand(batch, -1, -1)
            mem_k, mem_v = mem.chunk(2, dim=-1)
            mem_k = mem_k.unsqueeze(1).expand(-1, heads, -1, -1)
            mem_v = mem_v.unsqueeze(1).expand(-1, heads, -1, -1)
            qk = torch.cat([qk, mem_k], dim=2)
            v = torch.cat([v, mem_v], dim=2)
            if input_mask is not None:
                mem_mask = input_mask.new_ones(batch, self.num_mem_kv)
                input_mask = torch.cat([input_mask, mem_mask], dim=1)

        # Fold heads into the batch dimension for the low-level kernel.
        total_seq = qk.shape[2]
        qk_flat = qk.reshape(batch * heads, total_seq, dim_h)
        v_flat = v.reshape(batch * heads, total_seq, dim_h)

        if input_mask is not None:
            mask_flat = (
                input_mask.unsqueeze(1).expand(-1, heads, -1).reshape(batch * heads, -1)
            )
        else:
            mask_flat = None

        # Optionally process (batch*heads) in sequential slices to bound
        # peak memory.
        chunk_size = math.ceil(batch * heads / self.attn_chunks)
        out_chunks = []
        for i in range(0, batch * heads, chunk_size):
            sl = slice(i, i + chunk_size)
            out_chunk = self.lsh_attn(
                qk_flat[sl],
                v_flat[sl],
                input_mask=mask_flat[sl] if mask_flat is not None else None,
            )
            out_chunks.append(out_chunk)

        out = torch.cat(out_chunks, dim=0)

        # Drop any persistent-memory positions before restoring head layout.
        out = out[:, :seq_len, :]

        out = out.reshape(batch, heads, seq_len, dim_h)
        out = out.permute(0, 2, 1, 3).reshape(batch, seq_len, heads * dim_h)

        return self.to_out(out)
class ReformerLayer(nn.Module):
    """
    ReformerLayer with Locality-Sensitive Hashing (LSH) Self-Attention.

    Args:
        attention: unused (kept for API compatibility).
        d_model (int): input/output dimensionality.
        n_heads (int): number of attention heads.
        d_keys, d_values: unused placeholders for key/value dims.
        causal (bool): whether attention should be causal.
        bucket_size (int): LSH bucket size, used for padding computation.
        n_hashes (int): number of hash rounds used by LSH attention.
    """

    def __init__(
        self,
        attention,
        d_model,
        n_heads,
        d_keys=None,
        d_values=None,
        causal=False,
        bucket_size=4,
        n_hashes=4,
    ):
        super().__init__()

        self.bucket_size = bucket_size
        self.attn = LSHSelfAttention(
            dim=d_model,
            heads=n_heads,
            bucket_size=bucket_size,
            n_hashes=n_hashes,
            causal=causal,
        )

    def fit_length(self, queries):
        """Pad `queries` so its sequence length divides `bucket_size * 2`.

        Args:
            queries (torch.Tensor): tensor `[batch, seq_len, channels]`.

        Returns:
            torch.Tensor: `queries` unchanged when already aligned, otherwise
            right-padded with zeros along the sequence dimension.
        """
        B, N, C = queries.shape
        multiple = self.bucket_size * 2
        if N % multiple == 0:
            return queries
        fill_len = multiple - (N % multiple)
        # new_zeros keeps the dtype AND device of `queries`; the previous
        # torch.zeros(...).to(device) always padded with float32, which made
        # torch.cat fail for half/double precision inputs.
        return torch.cat([queries, queries.new_zeros(B, fill_len, C)], dim=1)

    def forward(self, queries, keys, values, attn_mask, tau, delta):
        """Run LSH self-attention on padded `queries`.

        Args:
            queries (torch.Tensor): input `[batch, seq_len, dim]`.
            keys, values, attn_mask, tau, delta: accepted for API
                compatibility but not used by this wrapper.

        Returns:
            tuple: `(output, None)` where `output` has shape
            `[batch, seq_len, dim]` (padding is stripped after attention).
        """
        B, N, C = queries.shape
        queries = self.attn(self.fit_length(queries))[:, :N, :]
        return queries, None
+""" + +from pytorch_forecasting.base._base_pkg import Base_pkg + + +class Reformer_pkg_v2(Base_pkg): + """Reformer metadata container.""" + + _tags = { + "info:name": "Reformer", + "authors": ["lucifer4073"], + "capability:exogenous": True, + "capability:multivariate": True, + "capability:pred_int": True, + "capability:flexible_history_length": False, + } + + @classmethod + def get_cls(cls): + """Get model class.""" + from pytorch_forecasting.models.reformer.reformer_v2 import Reformer + + return Reformer + + @classmethod + def get_datamodule_cls(cls): + """Get the underlying DataModule class.""" + from pytorch_forecasting.data._tslib_data_module import TslibDataModule + + return TslibDataModule + + @classmethod + def get_test_train_params(cls): + """Return testing parameter settings for the trainer. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. 
+ `create_test_instance` uses the first (or only) dictionary in `params` + """ + + params = [ + {}, + dict( + d_model=64, + n_heads=4, + ), + dict(datamodule_cfg=dict(context_length=12, prediction_length=3)), + dict( + d_model=32, + n_heads=2, + bucket_size=2, + n_hashes=2, + datamodule_cfg=dict( + context_length=12, + prediction_length=3, + add_relative_time_idx=False, + ), + ), + dict( + d_model=128, + e_layers=1, + d_ff=128, + datamodule_cfg=dict(context_length=16, prediction_length=4), + ), + dict( + n_heads=2, + e_layers=1, + bucket_size=4, + ), + dict( + d_model=256, + n_heads=8, + e_layers=3, + d_ff=1024, + bucket_size=8, + n_hashes=4, + activation="gelu", + dropout=0.2, + ), + dict( + d_model=32, + n_heads=2, + e_layers=1, + d_ff=64, + bucket_size=2, + n_hashes=2, + activation="relu", + dropout=0.05, + datamodule_cfg=dict( + context_length=16, + prediction_length=4, + ), + ), + ] + default_dm_cfg = {"context_length": 12, "prediction_length": 4} + + for param in params: + current_dm_cfg = param.get("datamodule_cfg", {}) + default_dm_cfg.update(current_dm_cfg) + + param["datamodule_cfg"] = default_dm_cfg + + return params diff --git a/pytorch_forecasting/models/reformer/reformer_v2.py b/pytorch_forecasting/models/reformer/reformer_v2.py new file mode 100644 index 000000000..abf459337 --- /dev/null +++ b/pytorch_forecasting/models/reformer/reformer_v2.py @@ -0,0 +1,404 @@ +""" +Reformer: Efficient Transformer for Long-Range Sequence Modeling +---------------------------------------------------------------- +""" + +from typing import Any, Optional, Union +import warnings as warn + +import torch +import torch.nn as nn +from torch.optim import Optimizer + +from pytorch_forecasting.models.base._tslib_base_model_v2 import TslibBaseModel + + +class Reformer(TslibBaseModel): + """ + An implementation of the Reformer model for pytorch-forecasting-v2. + + Parameters + ---------- + loss : nn.Module + Loss function to use for training. 
+ enc_in : int, optional + Number of input features for the encoder. If not provided, it will be set + to the number of continuous features in the dataset. + d_model : int, default=512 + Dimension of the model embeddings and hidden representations. + n_heads : int, default=8 + Number of attention heads in the LSH multi-head attention mechanism. + e_layers : int, default=2 + Number of encoder layers in the Reformer architecture. + d_ff : int, default=2048 + Dimension of the feed-forward network inside each encoder layer. + dropout : float, default=0.1 + Dropout rate applied throughout the model for regularization. + activation : str, default='gelu' + Activation function used in the feed-forward network. Common choices + are ``'relu'`` and ``'gelu'``. + embed : str, default='timeF' + Type of time feature embedding to use. Use ``'timeF'`` for time-frequency + embeddings or ``'fixed'`` / ``'learned'`` for positional embeddings. + task_name : str, default='long_term_forecast' + Forecasting task type. Either ``'long_term_forecast'`` or + ``'short_term_forecast'``. Short-term forecasting applies instance + normalization before encoding and denormalizes the output. + bucket_size : int, default=4 + Size of the LSH attention buckets. Queries and keys are hashed into + buckets of this size before computing attention within each bucket. + n_hashes : int, default=4 + Number of LSH hashing rounds. More rounds improve attention approximation + quality at the cost of additional computation. + logging_metrics : list[nn.Module] or None, default=None + List of additional metrics to log during training, validation, and testing. + optimizer : Optimizer or str or None, default='adam' + Optimizer to use for training. Can be a string name (e.g., ``'adam'``, + ``'sgd'``) or an instantiated :class:`torch.optim.Optimizer`. + optimizer_params : dict or None, default=None + Keyword arguments passed to the optimizer constructor. If ``None``, + the optimizer's default parameters are used. 
+ lr_scheduler : str or None, default=None + Learning rate scheduler to apply after each epoch. If ``None``, no + scheduler is used. + lr_scheduler_params : dict or None, default=None + Keyword arguments passed to the learning rate scheduler constructor. + If ``None``, the scheduler's default parameters are used. + metadata : dict or None, default=None + Metadata dictionary produced by ``TslibDataModule``. Must contain at + minimum a ``'freq'`` key describing the time series frequency (e.g., + ``'h'`` for hourly). Used to configure time-feature embeddings and to + infer dataset-level properties such as the number of continuous features. + + References + ---------- + [1] Kitaev, N., Kaiser, Ł., & Levskaya, A. (2020). Reformer: The Efficient + Transformer. https://arxiv.org/abs/2001.04451 + [2] https://github.com/thuml/Time-Series-Library/blob/main/models/Reformer.py + + Notes + ----- + This implementation handles only continuous variables. + """ + + @classmethod + def _pkg(cls): + """ + Return the package class that contains the Reformer model implementation. + + Returns + ------- + Reformer_pkg_v2 + The package class used to instantiate the underlying Reformer network. + """ + from pytorch_forecasting.models.reformer.reformer_pkg_v2 import Reformer_pkg_v2 + + return Reformer_pkg_v2 + + def __init__( + self, + loss: nn.Module, + enc_in: int = None, + d_model: int = 512, + n_heads: int = 8, + e_layers: int = 2, + d_ff: int = 2048, + dropout: float = 0.1, + activation: str = "gelu", + embed: str = "timeF", + task_name: str = "long_term_forecast", + bucket_size: int = 4, + n_hashes: int = 4, + logging_metrics: list[nn.Module] | None = None, + optimizer: Optimizer | str | None = "adam", + optimizer_params: dict | None = None, + lr_scheduler: str | None = None, + lr_scheduler_params: dict | None = None, + metadata: dict | None = None, + **kwargs: Any, + ): + """ + Initialize the Reformer model. 
+ + Calls the parent ``TslibBaseModel.__init__`` to handle shared setup + (loss, optimizer, scheduler, metadata), then stores all Reformer-specific + hyperparameters and calls ``_init_network`` to build the model layers. + + Parameters + ---------- + loss : nn.Module + Loss function to use for training. + enc_in : int, optional + Number of encoder input features. Defaults to ``cont_dim`` when + ``None``. + d_model : int, default=512 + Hidden dimension of all model layers. + n_heads : int, default=8 + Number of LSH attention heads. + e_layers : int, default=2 + Number of stacked encoder layers. + d_ff : int, default=2048 + Inner dimension of the position-wise feed-forward sublayers. + dropout : float, default=0.1 + Dropout probability. + activation : str, default='gelu' + Non-linearity for the feed-forward sublayers. + embed : str, default='timeF' + Time-feature embedding strategy. + task_name : str, default='long_term_forecast' + Controls whether short- or long-term forecasting logic is used. + bucket_size : int, default=4 + LSH bucket size for approximating full attention. + n_hashes : int, default=4 + Number of LSH hashing rounds. + logging_metrics : list[nn.Module] or None, default=None + Extra metrics logged during training and evaluation. + optimizer : Optimizer or str or None, default='adam' + Training optimizer. + optimizer_params : dict or None, default=None + Optimizer constructor arguments. + lr_scheduler : str or None, default=None + Learning rate scheduler name. + lr_scheduler_params : dict or None, default=None + Scheduler constructor arguments. + metadata : dict or None, default=None + Dataset metadata from ``TslibDataModule``. Must contain ``'freq'``. + **kwargs : Any + Additional keyword arguments forwarded to ``TslibBaseModel``. 
+ """ + super().__init__( + loss=loss, + logging_metrics=logging_metrics, + optimizer=optimizer, + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + lr_scheduler_params=lr_scheduler_params, + metadata=metadata, + ) + self.d_model = d_model + self.n_heads = n_heads + self.e_layers = e_layers + self.d_ff = d_ff + self.dropout = dropout + self.activation = activation + self.embed = embed + self.task_name = task_name + self.bucket_size = bucket_size + self.n_hashes = n_hashes + self.enc_in = enc_in + self.freq = metadata["freq"] + self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"]) + + self._init_network() + + def _init_network(self): + """ + Build and register all sub-modules of the Reformer architecture. + + Constructs the following components: + + - ``enc_embedding`` : :class:`DataEmbedding` that projects raw input + features (plus optional time marks) into the ``d_model``-dimensional + hidden space. + - ``encoder`` : :class:`ReformerEncoder` consisting of ``e_layers`` + stacked :class:`ReformerEncoderLayer` blocks, each containing LSH + self-attention (:class:`ReformerLayer`) and a position-wise + feed-forward sublayer, followed by layer normalisation. + - ``projection`` : :class:`torch.nn.Linear` that maps encoder outputs + from ``d_model`` dimensions to ``target_dim`` output channels. 
+ + """ + from pytorch_forecasting.layers import ( + DataEmbedding, + ReformerEncoder, + ReformerEncoderLayer, + ReformerLayer, + ) + + self.enc_in = self.enc_in or self.cont_dim + self.enc_embedding = DataEmbedding( + self.enc_in, + self.d_model, + self.embed, + self.freq, + self.dropout, + ) + + self.encoder = ReformerEncoder( + [ + ReformerEncoderLayer( + ReformerLayer( + None, + self.d_model, + self.n_heads, + bucket_size=self.bucket_size, + n_hashes=self.n_hashes, + ), + self.d_model, + self.d_ff, + dropout=self.dropout, + activation=self.activation, + ) + for l in range(self.e_layers) + ], + norm_layer=torch.nn.LayerNorm(self.d_model), + ) + + self.projection = nn.Linear(self.d_model, self.target_dim, bias=True) + + def _long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): + """ + Perform long-term forecasting via encoder-only pass with a placeholder. + + Parameters + ---------- + x_enc : torch.Tensor + Historical encoder input of shape ``(batch, context_length, enc_in)``. + x_mark_enc : torch.Tensor or None + Time-feature marks for the encoder sequence, shape + ``(batch, context_length, time_features)``, or ``None`` when + ``embed != 'timeF'``. + x_dec : torch.Tensor + Decoder input (used only to extract the future placeholder tokens), + shape ``(batch, label_length + prediction_length, enc_in)``. + x_mark_dec : torch.Tensor or None + Time-feature marks for the decoder sequence, or ``None``. + + Returns + ------- + torch.Tensor + Full encoder output of shape + ``(batch, context_length + prediction_length, target_dim)``. + Slice ``[:, -prediction_length:, :]`` to obtain the forecast. 
+ """ + x_enc = torch.cat([x_enc, x_dec[:, -self.prediction_length :, :]], dim=1) + if x_mark_enc is not None: + x_mark_enc = torch.cat( + [x_mark_enc, x_mark_dec[:, -self.prediction_length :, :]], dim=1 + ) + + enc_out = self.enc_embedding(x_enc, x_mark_enc) + enc_out, attns = self.encoder(enc_out, attn_mask=None) + dec_out = self.projection(enc_out) + + return dec_out + + def _short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): + """ + Perform short-term forecasting with instance normalization. + + Parameters + ---------- + x_enc : torch.Tensor + Historical encoder input of shape ``(batch, context_length, enc_in)``. + x_mark_enc : torch.Tensor or None + Time-feature marks for the encoder sequence, or ``None``. + x_dec : torch.Tensor + Decoder input for extracting future placeholder tokens, shape + ``(batch, label_length + prediction_length, enc_in)``. + x_mark_dec : torch.Tensor or None + Time-feature marks for the decoder sequence, or ``None``. + + Returns + ------- + torch.Tensor + Denormalized forecast of shape + ``(batch, context_length + prediction_length, target_dim)``. + Slice ``[:, -prediction_length:, :]`` to obtain the forecast. 
+ """ + mean_enc = x_enc[:, :, 0:1].mean(1, keepdim=True).detach() + std_enc = torch.sqrt( + torch.var(x_enc[:, :, 0:1], dim=1, keepdim=True, unbiased=False) + 1e-5 + ).detach() + all_mean = x_enc.mean(1, keepdim=True).detach() + all_std = torch.sqrt( + torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5 + ).detach() + x_enc = (x_enc - all_mean) / all_std + + x_enc = torch.cat([x_enc, x_dec[:, -self.prediction_length :, :]], dim=1) + if x_mark_enc is not None: + x_mark_enc = torch.cat( + [x_mark_enc, x_mark_dec[:, -self.prediction_length :, :]], dim=1 + ) + + enc_out = self.enc_embedding(x_enc, x_mark_enc) + enc_out, _ = self.encoder(enc_out, attn_mask=None) + dec_out = self.projection(enc_out) + + dec_out = dec_out * std_enc + mean_enc + return dec_out + + def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Route the input batch to the appropriate forecasting method. + + Parameters + ---------- + x : dict[str, torch.Tensor] + Batch dictionary expected to contain the following keys: + + - ``'history_cont'`` : ``(batch, context_length, enc_in)`` – historical + continuous covariates used as encoder input. + - ``'future_cont'`` : ``(batch, prediction_length, future_features)`` – + future continuous covariates used as decoder placeholder input. + - ``'history_time_idx'`` *(optional)* : ``(batch, context_length)`` – + integer or float time indices for encoder time-feature embedding. + - ``'future_time_idx'`` *(optional)* : ``(batch, prediction_length)`` – + integer or float time indices for decoder time-feature embedding. + + Returns + ------- + torch.Tensor + Raw model output of shape + ``(batch, context_length + prediction_length, target_dim)``. 
+ """ + x_enc = x["history_cont"] + x_dec = x["future_cont"] + if x_enc.shape[-1] != x_dec.shape[-1]: + diff = x_enc.shape[-1] - x_dec.shape[-1] + if diff > 0: + x_dec = torch.nn.functional.pad(x_dec, (0, diff)) + else: + x_dec = x_dec[..., : x_enc.shape[-1]] + + if self.embed == "timeF": + x_mark_enc = x["history_time_idx"].unsqueeze(-1).float() + x_mark_dec = x["future_time_idx"].unsqueeze(-1).float() + else: + x_mark_enc = None + x_mark_dec = None + + if self.task_name == "short_term_forecast": + dec_out = self._short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) + return dec_out + + dec_out = self._long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) + return dec_out + + def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Forward pass for Reformer model. + + Parameters + ---------- + x : dict[str, torch.Tensor] + Batch dictionary. See :meth:`_forecast` for the expected keys. + Additionally, if the key ``'target_scale'`` is present its value + is forwarded to :meth:`transform_output` for inverse scaling. + + Returns + ------- + dict[str, torch.Tensor] + + - ``'prediction'`` : ``(batch, prediction_length, target_dim)`` – + the model's forecast, optionally rescaled to the original target + units. + """ + out = self._forecast(x) + prediction = out[:, : self.prediction_length, :] + + if "target_scale" in x: + prediction = self.transform_output(prediction, x["target_scale"]) + + return {"prediction": prediction}