Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pytorch_forecasting/layers/_embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
DataEmbedding_inverted,
)
from pytorch_forecasting.layers._embeddings._en_embedding import EnEmbedding
from pytorch_forecasting.layers._embeddings._patch_embedding import PatchEmbedding
from pytorch_forecasting.layers._embeddings._positional_embedding import (
PositionalEmbedding,
)
from pytorch_forecasting.layers._embeddings._sub_nn import embedding_cat_variables

__all__ = [
"PatchEmbedding",
"PositionalEmbedding",
"DataEmbedding_inverted",
"EnEmbedding",
Expand Down
114 changes: 114 additions & 0 deletions pytorch_forecasting/layers/_embeddings/_patch_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
Patch Embedding Layer for PatchTST.

Adapted from thuml/Time-Series-Library:
https://github.com/thuml/Time-Series-Library/blob/main/layers/Embed.py
Paper: https://arxiv.org/pdf/2211.14730.pdf
"""

import torch
import torch.nn as nn

from pytorch_forecasting.layers._embeddings._positional_embedding import (
PositionalEmbedding,
)


class PatchEmbedding(nn.Module):
    """
    Patch embedding layer for PatchTST-style models.

    Each variable's (channel's) series is cut into fixed-length, possibly
    overlapping windows ("patches"), and every patch is linearly projected
    into a ``d_model``-dimensional token. Channels are patched independently
    of one another — the "Channel Independence" (CI) assumption of PatchTST.

    Parameters
    ----------
    d_model : int
        Size of the embedding produced for every patch.
    patch_len : int
        Number of time steps covered by a single patch.
    stride : int
        Offset between the starts of consecutive patches; values smaller
        than ``patch_len`` yield overlapping patches.
    padding : int
        Number of replicated time steps appended at the right edge before
        patching. Using ``padding = stride`` guarantees the final time
        steps are not dropped by the unfold.
    dropout : float
        Dropout probability applied to the embedded patches.

    Notes
    -----
    ``forward`` flattens the batch and variable axes into one, so a
    downstream Transformer encoder treats every (sample, variable) pair as
    its own token sequence. The returned ``n_vars`` lets the caller undo
    that flattening after the encoder has run.
    """

    def __init__(
        self,
        d_model: int,
        patch_len: int,
        stride: int,
        padding: int,
        dropout: float,
    ):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride

        # Replicate the last value(s) on the right so the final window is full.
        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))

        # Project a raw patch (length `patch_len`) to a `d_model` token.
        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)

        # Fixed sinusoidal position information for the patch tokens.
        self.position_embedding = PositionalEmbedding(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int]:
        """
        Embed a batch of multivariate series as patch tokens.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(batch_size, n_vars, seq_len)`` — the time axis is last
            (already transposed from the usual ``(batch, time, vars)``).

        Returns
        -------
        x : torch.Tensor
            Patch tokens of shape ``(batch_size * n_vars, n_patches, d_model)``.
        n_vars : int
            Channel count, needed by the caller to restore the
            ``(batch_size, n_vars, n_patches, d_model)`` layout later.
        """
        batch_size, n_vars = x.shape[0], x.shape[1]

        # Right-pad, then slide a window of length `patch_len` with step
        # `stride`: (batch_size, n_vars, n_patches, patch_len)
        padded = self.padding_patch_layer(x)
        patches = padded.unfold(dimension=-1, size=self.patch_len, step=self.stride)

        # Fold batch and channel axes together so each (sample, variable)
        # pair becomes an independent token sequence:
        # (batch_size * n_vars, n_patches, patch_len)
        patches = patches.reshape(
            batch_size * n_vars, patches.shape[2], patches.shape[3]
        )

        # Linear projection plus sinusoidal positions, then dropout.
        embedded = self.value_embedding(patches) + self.position_embedding(patches)
        return self.dropout(embedded), n_vars
6 changes: 5 additions & 1 deletion pytorch_forecasting/layers/_encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,9 @@

from pytorch_forecasting.layers._encoders._encoder import Encoder
from pytorch_forecasting.layers._encoders._encoder_layer import EncoderLayer
from pytorch_forecasting.layers._encoders._self_attn_encoder import SelfAttnEncoder
from pytorch_forecasting.layers._encoders._self_attn_encoder_layer import (
SelfAttnEncoderLayer,
)

__all__ = ["Encoder", "EncoderLayer"]
__all__ = ["Encoder", "EncoderLayer", "SelfAttnEncoder", "SelfAttnEncoderLayer"]
63 changes: 63 additions & 0 deletions pytorch_forecasting/layers/_encoders/_self_attn_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
Self-attention-only Encoder for PatchTST and similar encoder-only models.

Adapted from thuml/Time-Series-Library:
https://github.com/thuml/Time-Series-Library/blob/main/layers/Transformer_EncDec.py
"""

import torch
import torch.nn as nn


class SelfAttnEncoder(nn.Module):
    """
    Stack of ``SelfAttnEncoderLayer`` blocks for encoder-only models.

    Used by architectures such as PatchTST that have no decoder and no
    cross-attention: the input simply flows through every self-attention
    layer in order, with an optional normalization applied at the end.

    Parameters
    ----------
    attn_layers : list[nn.Module]
        The ``SelfAttnEncoderLayer`` instances, applied in list order.
    norm_layer : nn.Module, optional
        Normalization applied after the last layer. PatchTST passes
        ``nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(d_model),
        Transpose(1, 2))`` so BatchNorm runs over the feature dimension.
        Defaults to ``None`` (no final normalization).
    """

    def __init__(self, attn_layers, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """
        Run the input through every encoder layer in sequence.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(batch_size, seq_len, d_model)``.
        attn_mask : optional
            Mask forwarded unchanged to each layer (``None`` for PatchTST).

        Returns
        -------
        x : torch.Tensor
            Encoded output, same shape as the input.
        attns : list
            One attention-weight entry per layer; entries are ``None`` when
            the inner attention does not return weights
            (``output_attention=False``).
        """
        attns = []
        for layer in self.attn_layers:
            x, layer_attn = layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
            attns.append(layer_attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""
Self-attention Encoder Layer for PatchTST and similar encoder-only models.

Adapted from thuml/Time-Series-Library:
https://github.com/thuml/Time-Series-Library/blob/main/layers/Transformer_EncDec.py
"""

import torch.nn as nn
import torch.nn.functional as F


class SelfAttnEncoderLayer(nn.Module):
    """
    Single self-attention-only Transformer encoder block.

    In contrast to the ``EncoderLayer`` in ``_encoder_layer.py`` (which also
    performs cross-attention for TimeXer), this is a standard post-norm
    encoder block:

    - multi-head self-attention with a residual connection,
    - a position-wise feed-forward network built from two kernel-size-1
      1-D convolutions with a residual connection,
    - dropout inside both sub-layers and layer normalization after each
      residual sum.

    Parameters
    ----------
    attention : nn.Module
        ``AttentionLayer`` wrapping the inner attention mechanism
        (typically ``FullAttention``).
    d_model : int
        Embedding dimension of the model.
    d_ff : int, optional
        Hidden width of the feed-forward network; falls back to
        ``4 * d_model`` when not given.
    dropout : float
        Dropout probability for both sub-layers. Defaults to 0.1.
    activation : str
        ``"relu"`` selects ReLU; any other value selects GELU.
        Defaults to ``"relu"``.
    """

    def __init__(
        self,
        attention,
        d_model: int,
        d_ff: int = None,
        dropout: float = 0.1,
        activation: str = "relu",
    ):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        # Kernel-size-1 convolutions act as per-position linear layers on
        # the feature (channel) dimension.
        self.conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """
        Forward pass of the encoder layer.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(batch_size, seq_len, d_model)``.
        attn_mask : optional
            Attention mask (usually ``None`` for PatchTST).

        Returns
        -------
        x : torch.Tensor
            Output of shape ``(batch_size, seq_len, d_model)``.
        attn : torch.Tensor or None
            Attention weights, or ``None`` when the inner attention does
            not return them (``output_attention=False``).
        """
        # Self-attention sub-layer: residual sum, then LayerNorm (post-norm).
        attn_out, attn = self.attention(
            x, x, x, attn_mask=attn_mask, tau=tau, delta=delta
        )
        x = self.norm1(x + self.dropout(attn_out))

        # Feed-forward sub-layer: the convs expect channels in dim 1, so
        # move d_model onto the channel axis and back afterwards.
        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        ff_out = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm2(x + ff_out), attn
8 changes: 8 additions & 0 deletions pytorch_forecasting/models/patchtst/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""
PatchTST model for time series forecasting.
"""

from pytorch_forecasting.models.patchtst._patchtst_pkg_v2 import PatchTST_pkg_v2
from pytorch_forecasting.models.patchtst._patchtst_v2 import PatchTST

__all__ = ["PatchTST", "PatchTST_pkg_v2"]
Loading
Loading