diff --git a/i6_models/assemblies/transducer/__init__.py b/i6_models/assemblies/transducer/__init__.py
new file mode 100644
index 00000000..04c78687
--- /dev/null
+++ b/i6_models/assemblies/transducer/__init__.py
@@ -0,0 +1,2 @@
+from .prediction_network import *
+from .joint_network import *
diff --git a/i6_models/assemblies/transducer/joint_network.py b/i6_models/assemblies/transducer/joint_network.py
new file mode 100644
index 00000000..b4a9ec9b
--- /dev/null
+++ b/i6_models/assemblies/transducer/joint_network.py
@@ -0,0 +1,81 @@
+__all__ = ["TransducerJointNetworkV1Config", "TransducerJointNetworkV1"]
+
+from dataclasses import dataclass
+from typing import Tuple
+
+import torch
+from torch import nn
+
+from i6_models.config import ModelConfiguration
+from i6_models.parts.ffnn import FeedForwardBlockV1Config, FeedForwardBlockV1
+
+
+@dataclass
+class TransducerJointNetworkV1Config(ModelConfiguration):
+    """
+    Configuration for the Transducer Joint Network.
+
+    Attributes:
+        ffnn_cfg: Configuration for the internal feed-forward network.
+    """
+
+    ffnn_cfg: FeedForwardBlockV1Config
+
+
+class TransducerJointNetworkV1(nn.Module):
+    def __init__(
+        self,
+        cfg: TransducerJointNetworkV1Config,
+    ) -> None:
+        super().__init__()
+        self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)
+        self.output_dim = self.ffnn.output_dim
+
+    def forward(
+        self,
+        source_encodings: torch.Tensor,  # [1, T, E]
+        target_encodings: torch.Tensor,  # [B, S, P]
+    ) -> torch.Tensor:  # [B, T, S, F]
+        """
+        Forward pass for recognition.
+        """
+        combined_encodings = source_encodings.unsqueeze(2) + target_encodings.unsqueeze(1)
+        output = self.ffnn(combined_encodings)  # [B, T, S, F]
+
+        if not self.training:
+            output = torch.log_softmax(output, dim=-1)  # [B, T, S, F]
+        return output
+
+    def forward_viterbi(
+        self,
+        source_encodings: torch.Tensor,  # [B, T, E]
+        source_lengths: torch.Tensor,  # [B]
+        target_encodings: torch.Tensor,  # [B, T, P]
+        target_lengths: torch.Tensor,  # [B]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:  # [B, T, F], [B], [B]
+        """
+        Forward pass for Viterbi training.
+        """
+        combined_encodings = source_encodings + target_encodings
+        output = self.ffnn(combined_encodings)  # [B, T, F]
+        if not self.training:
+            output = torch.log_softmax(output, dim=-1)  # [B, T, F]
+        return output, source_lengths, target_lengths
+
+    def forward_fullsum(
+        self,
+        source_encodings: torch.Tensor,  # [B, T, E]
+        source_lengths: torch.Tensor,  # [B]
+        target_encodings: torch.Tensor,  # [B, S+1, P]
+        target_lengths: torch.Tensor,  # [B]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:  # [B, T, S+1, F], [B], [B]
+        """
+        Forward pass for full-sum training; returns the joint output [B, T, S+1, F] and the unchanged lengths.
+        """
+
+        # additive combination, broadcast over T and S+1 (requires E == P)
+        combined_encodings = source_encodings.unsqueeze(2) + target_encodings.unsqueeze(1)
+
+        # pass through FFNN
+        output = self.ffnn(combined_encodings)  # [B, T, S+1, F]
+        return output, source_lengths, target_lengths
+ """ + emb_expanded = emb.unsqueeze(3) # [B, S, H, 1, E] + pos_expanded = self.position_vectors.unsqueeze(0).unsqueeze(0) # [1, 1, H, K, E] + alpha = (emb_expanded * pos_expanded).sum(dim=-1, keepdim=True) # [B, S, H, K, 1] + weighted = alpha * emb_expanded # [B, S, H, K, E] + reduced = weighted.sum(dim=2).sum(dim=2) # [B, S, E] + reduced *= 1.0 / (self.cfg.num_reduction_heads * self.cfg.context_history_size) + return reduced + + def _forward_embedding(self, history: torch.Tensor) -> torch.Tensor: + """ + Processes the input history through the embedding layer and optional reduction. + """ + if len(history.shape) == 2: # reshape if input shape [B, H] + history = history.view(*history.shape[:-1], 1, history.shape[-1]) # [B, 1, H] + embed = self.embedding(history) # [B, S, H, E] + if self.reduce_embedding: + embed = self._reduce_embedding(embed) # [B, S, E] + else: + embed = embed.flatten(start_dim=-2) # [B, S, H*E] + return embed + + def forward( + self, + history: torch.Tensor, # [B, H] + ) -> torch.Tensor: # [B, 1, P] + """ + Forward pass for recognition mode. + """ + embed = self._forward_embedding(history) + return embed + + def forward_fullsum( + self, + targets: torch.Tensor, # [B, S] + target_lengths: torch.Tensor, # [B] + ) -> Tuple[torch.Tensor, torch.Tensor]: # [B, S + 1, P], [B] + """ + Forward pass for fullsum training. + """ + non_context_padding = torch.full( + (targets.size(0), self.cfg.context_history_size), + fill_value=self.blank_id, + dtype=targets.dtype, + device=targets.device, + ) # [B, H] + extended_targets = torch.cat([non_context_padding, targets], dim=1) # [B, S+H] + history = torch.stack( + [ + extended_targets[:, self.cfg.context_history_size - 1 - i : (-i if i != 0 else None)] + for i in reversed(range(self.cfg.context_history_size)) + ], + dim=-1, + ) # [B, S+1, H] + embed = self._forward_embedding(history) + + return embed, target_lengths + + def forward_viterbi( + self, + targets: torch.Tensor, # [B, T] + target_lengths: torch.Tensor, # [B] + ) -> Tuple[torch.Tensor, torch.Tensor]: # [B, T, P], [B] + """ + Forward pass for viterbi training. + """ + B, T = targets.shape + history = torch.zeros( + (B, T, self.cfg.context_history_size), + dtype=targets.dtype, + device=targets.device, + ) # [B, T, H] + recent_labels = torch.full( + (B, self.cfg.context_history_size), + fill_value=self.blank_id, + dtype=targets.dtype, + device=targets.device, + ) # [B, H] + + for t in range(T): + history[:, t, :] = recent_labels + current_labels = targets[:, t] + non_blank_positions = current_labels != self.blank_id + recent_labels[non_blank_positions, 1:] = recent_labels[non_blank_positions, :-1] + recent_labels[non_blank_positions, 0] = current_labels[non_blank_positions] + embed = self._forward_embedding(history) + + return embed, target_lengths + + +@dataclass +class FfnnTransducerPredictionNetworkV1Config(EmbeddingTransducerPredictionNetworkV1Config): + """ + Attributes: + ffnn_cfg: Configuration for FFNN prediction network + """ + + ffnn_cfg: FeedForwardBlockV1Config + + +class FfnnTransducerPredictionNetworkV1(EmbeddingTransducerPredictionNetworkV1): + """ + FfnnTransducerPredictionNetworkV1 with feedforward layers. 
+ """ + + def __init__(self, cfg: FfnnTransducerPredictionNetworkV1Config): + super().__init__(EmbeddingTransducerPredictionNetworkV1Config.from_child(cfg)) + cfg.ffnn_cfg.input_dim = self.output_dim + self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg) + self.output_dim = self.ffnn.output_dim + + def forward( + self, + history: torch.Tensor, # [B, H] + ) -> torch.Tensor: # [B, 1, P] + embed = super().forward(history) + output = self.ffnn(embed) + return output + + def forward_fullsum( + self, + targets: torch.Tensor, # [B, S] + target_lengths: torch.Tensor, # [B] + ) -> Tuple[torch.Tensor, torch.Tensor]: # [B, S + 1, P], [B] + embed, _ = super().forward_fullsum(targets, target_lengths) + output = self.ffnn(embed) + return output, target_lengths + + def forward_viterbi( + self, + targets: torch.Tensor, # [B, T] + target_lengths: torch.Tensor, # [B] + ) -> Tuple[torch.Tensor, torch.Tensor]: # [B, T, P], [B] + embed, _ = super().forward_viterbi(targets, target_lengths) + output = self.ffnn(embed) + return output, target_lengths diff --git a/i6_models/parts/ffnn.py b/i6_models/parts/ffnn.py index 3bd30d38..5ecc91f5 100644 --- a/i6_models/parts/ffnn.py +++ b/i6_models/parts/ffnn.py @@ -1,8 +1,13 @@ -__all__ = ["FeedForwardConfig", "FeedForwardModel"] +__all__ = [ + "FeedForwardLayerV1Config", + "FeedForwardLayerV1", + "FeedForwardBlockV1Config", + "FeedForwardBlockV1", +] from dataclasses import dataclass from functools import partial -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union, List import torch from torch import nn @@ -57,3 +62,66 @@ def forward( tensor = self.activation(tensor) # [B,T,F] tensor = self.dropout(tensor) # [B,T,F] return tensor, sequence_mask + + +@dataclass +class FeedForwardBlockV1Config(ModelConfiguration): + """ + Configuration for the FeedForwardBlockV1 module. + + Attributes: + input_dim: Input feature dimension. + layer_sizes: List of hidden layer sizes. The length of this list + determines the number of layers. + dropouts: Dropout probability for each layer. + layer_activations: List of activation function applied after each linear layer. + None represents no activation. + Must have the same length as layer_sizes. + use_layer_norm: Whether to use Layer Normalization. + """ + + input_dim: int + layer_sizes: List[int] + dropouts: List[float] + layer_activations: List[Optional[Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]]] + use_layer_norm: bool = True + + def __post_init__(self): + super().__post_init__() + assert all(0.0 <= dropout <= 1.0 for dropout in self.dropouts), "Dropout values must be probabilities" + assert len(self.layer_sizes) > 0, "layer_sizes must not be empty" + assert len(self.layer_sizes) == len(self.layer_activations) + assert len(self.layer_sizes) == len(self.dropouts) + + +class FeedForwardBlockV1(nn.Module): + """ + A multi-layer feed-forward network block with optional Layer Normalization. 
+ """ + + def __init__(self, cfg: FeedForwardBlockV1Config): + super().__init__() + self.cfg = cfg + network_layers: List[nn.Module] = [] + prev_size = cfg.input_dim + + for i, layer_size in enumerate(cfg.layer_sizes): + if cfg.use_layer_norm: + network_layers.append(nn.LayerNorm(prev_size)) + network_layers.append(nn.Linear(prev_size, layer_size)) + prev_size = layer_size + if cfg.layer_activations[i] is not None: + network_layers.append(cfg.layer_activations[i]) + network_layers.append(nn.Dropout(cfg.dropouts[i])) + + self.output_dim = cfg.layer_sizes[-1] + self.network = nn.Sequential(*network_layers) + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the feed-forward block. + + :param tensor: Input tensor of shape [B, T, F], where F is input_dim. + :return: Output tensor of shape [B, T, output_dim]. + """ + return self.network(tensor)