2 changes: 2 additions & 0 deletions i6_models/assemblies/transducer/__init__.py
@@ -0,0 +1,2 @@
from .prediction_network import *
from .joint_network import *
81 changes: 81 additions & 0 deletions i6_models/assemblies/transducer/joint_network.py
@@ -0,0 +1,81 @@
__all__ = ["TransducerJointNetworkV1Config", "TransducerJointNetworkV1"]

from dataclasses import dataclass
from typing import Any, Dict, Tuple, Union

import torch
from torch import nn

from i6_models.config import ModelConfiguration
from i6_models.parts.ffnn import FeedForwardBlockV1Config, FeedForwardBlockV1


@dataclass
class TransducerJointNetworkV1Config(ModelConfiguration):
"""
Configuration for the Transducer Joint Network.
Attributes:
ffnn_cfg: Configuration for the internal feed-forward network.
"""

ffnn_cfg: FeedForwardBlockV1Config


class TransducerJointNetworkV1(nn.Module):
Member: Needs docs.

def __init__(
self,
cfg: TransducerJointNetworkV1Config,
) -> None:
Comment on lines +26 to +29 (Member): Suggested change:
def __init__(
self,
cfg: TransducerJointNetworkV1Config,
) -> None:
def __init__(self, cfg: TransducerJointNetworkV1Config):

super().__init__()
self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)
self.output_dim = self.ffnn.output_dim

def forward(
self,
source_encodings: torch.Tensor, # [1, T, E]
target_encodings: torch.Tensor, # [B, S, P]
Comment on lines +36 to +37 (Contributor): Are source_encodings the output of the acoustic encoder and target_encodings the output of the prediction network? Maybe we could rename (and document) these better.

) -> torch.Tensor: # [B, T, S, F]
"""
Forward pass for recognition.
"""
combined_encodings = source_encodings.unsqueeze(2) + target_encodings.unsqueeze(1)
Member: See comment wrt. dim indices from the rear.

output = self.ffnn(combined_encodings) # [B, T, S, F]

if not self.training:
output = torch.log_softmax(output, dim=-1) # [B, T, S, F]
Comment on lines +45 to +46 (Member): I'm not a big fan of switching between logits and log probs based on whether it's train time or not. I'd rather pass a parameter or leave the log softmax to the forward_step function.
Contributor: +1, I think we usually get logits in the train step and apply the appropriate softmax function there.

return output
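A minimal sketch of the alternative the reviewers describe above (hypothetical, not part of the PR): the module always computes logits, and normalization is opt-in via an explicit parameter or left to the calling step.

import torch

def joint_output(combined_encodings: torch.Tensor, ffnn: torch.nn.Module, apply_log_softmax: bool = False) -> torch.Tensor:
    # always compute raw logits; the caller decides whether to normalize
    output = ffnn(combined_encodings)
    if apply_log_softmax:
        output = torch.log_softmax(output, dim=-1)
    return output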

def forward_viterbi(
self,
source_encodings: torch.Tensor, # [B, T, E]
source_lengths: torch.Tensor, # [B]
target_encodings: torch.Tensor, # [B, T, P]
target_lengths: torch.Tensor, # [B]
) -> torch.Tensor: # [B, T, F]
"""
Forward pass for Viterbi training.
"""
combined_encodings = source_encodings + target_encodings
output = self.ffnn(combined_encodings) # [B, T, F]
if not self.training:
output = torch.log_softmax(output, dim=-1) # [B, T, F]
Comment on lines +61 to +62 (Member): See above.

return output, source_lengths, target_lengths

def forward_fullsum(
self,
source_encodings: torch.Tensor, # [B, T, E]
source_lengths: torch.Tensor, # [B]
target_encodings: torch.Tensor, # [B, S+1, P]
target_lengths: torch.Tensor, # [B]
) -> torch.Tensor: # [B, T, S+1, F]
"""
Forward pass for fullsum training. Returns output with shape [B, T, S+1, F].
"""

# additive combination
combined_encodings = source_encodings.unsqueeze(2) + target_encodings.unsqueeze(1)
Member: Consider unsqueezing from the back (i.e. use negative dim indices like .unsqueeze(-X)) to be more generic across varying numbers of (batch) dims, just in case this comes in handy at any time in the future.


# Pass through FFNN
output = self.ffnn(combined_encodings) # [B, T, S+1, F]
return output, source_lengths, target_lengths
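A small, self-contained check of the negative-dim suggestion (illustrative only): unsqueezing from the back gives the same broadcast combination for [B, T, E] / [B, S, E] inputs and keeps working if an extra leading batch dim is added.

import torch

src = torch.randn(2, 7, 16)  # [B, T, E]
tgt = torch.randn(2, 5, 16)  # [B, S, E]
front = src.unsqueeze(2) + tgt.unsqueeze(1)    # [B, T, S, E], as written in the PR
back = src.unsqueeze(-2) + tgt.unsqueeze(-3)   # same result, dims indexed from the rear
assert torch.equal(front, back)
# with an additional leading dim, only the negative-index form still lines up as intended
back_nd = src.unsqueeze(0).unsqueeze(-2) + tgt.unsqueeze(0).unsqueeze(-3)  # [1, B, T, S, E]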
219 changes: 219 additions & 0 deletions i6_models/assemblies/transducer/prediction_network.py
@@ -0,0 +1,219 @@
__all__ = [
"EmbeddingTransducerPredictionNetworkV1Config",
"EmbeddingTransducerPredictionNetworkV1",
"FfnnTransducerPredictionNetworkV1Config",
"FfnnTransducerPredictionNetworkV1",
]

from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn

from i6_models.config import ModelConfiguration
from i6_models.parts.ffnn import FeedForwardBlockV1Config, FeedForwardBlockV1


@dataclass
class EmbeddingTransducerPredictionNetworkV1Config(ModelConfiguration):
"""
num_outputs: Number of output units (vocabulary size + blank).
blank_id: Index of the blank token.
context_history_size: Number of previous output tokens to consider as context.
embedding_dim: Dimension of the embedding layer.
reduce_embedding: Whether to use a reduction mechanism for the context embedding.
num_reduction_heads: Number of reduction heads if reduce_embedding is True.
"""

num_outputs: int
blank_id: int
context_history_size: int
embedding_dim: int
reduce_embedding: bool
num_reduction_heads: Optional[int]

def __post__init__(self):
Contributor: typo. Suggested change:
def __post__init__(self):
def __post_init__(self):

super().__post_init__()
assert (num_reduction_heads is not None) == reduce_embedding

@classmethod
def from_child(cls, child_instance):
return cls(
child_instance.num_outputs,
child_instance.blank_id,
child_instance.context_history_size,
child_instance.embedding_dim,
child_instance.reduce_embedding,
child_instance.num_reduction_heads,
)
Comment on lines +41 to +49 (Member): What does this do? Why is it necessary, and how is it different from config2 = copy.deepcopy(config1)?
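For illustration, a minimal self-contained toy (hypothetical names, not the PR's classes) of the difference being asked about: deepcopy keeps the child class, while a from_child-style call builds a parent-class instance from the shared fields only.

import copy
from dataclasses import dataclass

@dataclass
class Parent:
    num_outputs: int

@dataclass
class Child(Parent):
    extra: int = 0

child = Child(num_outputs=10, extra=1)
assert type(copy.deepcopy(child)) is Child        # deepcopy preserves the subclass
assert type(Parent(child.num_outputs)) is Parent  # from_child-style: plain parent config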



class EmbeddingTransducerPredictionNetworkV1(nn.Module):
Member: Needs docs.

def __init__(self, cfg: EmbeddingTransducerPredictionNetworkV1Config) -> None:
Member: Suggested change:
def __init__(self, cfg: EmbeddingTransducerPredictionNetworkV1Config) -> None:
def __init__(self, cfg: EmbeddingTransducerPredictionNetworkV1Config):

super().__init__()
self.cfg = cfg
self.blank_id = self.cfg.blank_id
self.context_history_size = self.cfg.context_history_size
self.embedding = nn.Embedding(
num_embeddings=self.cfg.num_outputs,
embedding_dim=self.cfg.embedding_dim,
padding_idx=self.blank_id,
)
self.output_dim = (
self.cfg.embedding_dim * self.cfg.context_history_size
if not self.cfg.reduce_embedding
else self.cfg.embedding_dim
)

self.reduce_embedding = self.cfg.reduce_embedding
self.num_reduction_heads = self.cfg.num_reduction_heads
if self.reduce_embedding:
self.register_buffer(
"position_vectors",
torch.randn(
self.cfg.context_history_size,
self.cfg.num_reduction_heads,
self.cfg.embedding_dim,
),
)

def _reduce_embedding(self, emb: torch.Tensor) -> torch.Tensor:
"""
Reduces the context embedding using a weighted sum based on position vectors.
"""
emb_expanded = emb.unsqueeze(3) # [B, S, H, 1, E]
Member: Consider unsqueezing from the back.

pos_expanded = self.position_vectors.unsqueeze(0).unsqueeze(0) # [1, 1, H, K, E]
alpha = (emb_expanded * pos_expanded).sum(dim=-1, keepdim=True) # [B, S, H, K, 1]
weighted = alpha * emb_expanded # [B, S, H, K, E]
reduced = weighted.sum(dim=2).sum(dim=2) # [B, S, E]
Member: Consider indexing dims from the back.
Contributor: dim can be a tuple of ints, so we could do it in one step (https://docs.pytorch.org/docs/stable/generated/torch.sum.html). Suggested change:
reduced = weighted.sum(dim=2).sum(dim=2) # [B, S, E]
reduced = weighted.sum(dim=(-2, -1)) # [B, S, E]

reduced *= 1.0 / (self.cfg.num_reduction_heads * self.cfg.context_history_size)
return reduced
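Following up on the tuple-dim suggestion: with the [B, S, H, K, E] layout from the comments, the two collapsed axes (H and K) sit at -3 and -2 when counted from the back. A self-contained check (illustrative):

import torch

weighted = torch.randn(2, 5, 3, 4, 16)      # [B, S, H, K, E]
two_step = weighted.sum(dim=2).sum(dim=2)   # [B, S, E], as in the PR
one_step = weighted.sum(dim=(-3, -2))       # [B, S, E], single call over H and K
assert torch.allclose(two_step, one_step)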

def _forward_embedding(self, history: torch.Tensor) -> torch.Tensor:
"""
Processes the input history through the embedding layer and optional reduction.
"""
if len(history.shape) == 2: # reshape if input shape [B, H]
history = history.view(*history.shape[:-1], 1, history.shape[-1]) # [B, 1, H]
Contributor: *history.shape[:-1] reads odd; that should be the same as history.shape[0], since we have len(history.shape) == 2. But talk to @NeoLegends about making this work with more batch dims.

embed = self.embedding(history) # [B, S, H, E]
if self.reduce_embedding:
embed = self._reduce_embedding(embed) # [B, S, E]
else:
embed = embed.flatten(start_dim=-2) # [B, S, H*E]
return embed
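Regarding the view call flagged above: for a 2-D [B, H] history it is equivalent to unsqueeze(-2), which also generalizes to extra leading batch dims. A small check (illustrative):

import torch

history = torch.randint(0, 10, (3, 2))  # [B, H]
viewed = history.view(*history.shape[:-1], 1, history.shape[-1])  # [B, 1, H]
assert torch.equal(viewed, history.unsqueeze(-2))
assert torch.randint(0, 10, (4, 3, 2)).unsqueeze(-2).shape == (4, 3, 1, 2)  # [..., 1, H]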

def forward(
self,
history: torch.Tensor, # [B, H]
) -> torch.Tensor: # [B, 1, P]
"""
Forward pass for recognition mode.
"""
embed = self._forward_embedding(history)
return embed

def forward_fullsum(
self,
targets: torch.Tensor, # [B, S]
target_lengths: torch.Tensor, # [B]
) -> Tuple[torch.Tensor, torch.Tensor]: # [B, S + 1, P], [B]
"""
Forward pass for fullsum training.
"""
non_context_padding = torch.full(
(targets.size(0), self.cfg.context_history_size),
fill_value=self.blank_id,
dtype=targets.dtype,
device=targets.device,
) # [B, H]
extended_targets = torch.cat([non_context_padding, targets], dim=1) # [B, S+H]
history = torch.stack(
[
extended_targets[:, self.cfg.context_history_size - 1 - i : (-i if i != 0 else None)]
for i in reversed(range(self.cfg.context_history_size))
],
dim=-1,
) # [B, S+1, H]
Comment on lines +124 to +137 (Contributor): ChatGPT suggested this code:

        B, S = targets.shape
        H = self.cfg.context_history_size

        # Pad left with H blanks: [B, S+H]
        extended = F.pad(targets, (H, 0), value=self.blank_id)

        # Unfold over sequence dim to get [B, S+1, H]
        # (PyTorch: unfold(size=H, step=1) "slides" a length-H window)
        history = extended.unfold(dimension=1, size=H, step=1)  # [B, S+1, H]
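A quick, self-contained check (illustrative; note the suggestion needs torch.nn.functional for F.pad) that the pad-and-unfold construction matches the stack-based history built in the PR:

import torch
import torch.nn.functional as F

targets = torch.tensor([[3, 1, 4, 1]])  # [B, S]
H, blank_id = 2, 0
ext = torch.cat([torch.full((1, H), blank_id), targets], dim=1)  # [B, S+H]
stacked = torch.stack(
    [ext[:, H - 1 - i : (-i if i != 0 else None)] for i in reversed(range(H))], dim=-1
)  # [B, S+1, H], as in the PR
unfolded = F.pad(targets, (H, 0), value=blank_id).unfold(dimension=1, size=H, step=1)
assert torch.equal(stacked, unfolded)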

embed = self._forward_embedding(history)

return embed, target_lengths

def forward_viterbi(
self,
targets: torch.Tensor, # [B, T]
target_lengths: torch.Tensor, # [B]
) -> Tuple[torch.Tensor, torch.Tensor]: # [B, T, P], [B]
"""
Forward pass for Viterbi training.
"""
B, T = targets.shape
history = torch.zeros(
(B, T, self.cfg.context_history_size),
dtype=targets.dtype,
device=targets.device,
) # [B, T, H]
recent_labels = torch.full(
(B, self.cfg.context_history_size),
fill_value=self.blank_id,
dtype=targets.dtype,
device=targets.device,
) # [B, H]

for t in range(T):
history[:, t, :] = recent_labels
current_labels = targets[:, t]
non_blank_positions = current_labels != self.blank_id
recent_labels[non_blank_positions, 1:] = recent_labels[non_blank_positions, :-1]
recent_labels[non_blank_positions, 0] = current_labels[non_blank_positions]
embed = self._forward_embedding(history)

return embed, target_lengths


@dataclass
class FfnnTransducerPredictionNetworkV1Config(EmbeddingTransducerPredictionNetworkV1Config):
"""
Attributes:
ffnn_cfg: Configuration for FFNN prediction network
"""

ffnn_cfg: FeedForwardBlockV1Config


class FfnnTransducerPredictionNetworkV1(EmbeddingTransducerPredictionNetworkV1):
Member: I think this class would benefit from using composition instead of inheritance. Make it contain/own an EmbeddingTransducerPredictionNetworkV1 instead of inheriting from one. That resolves all your issues wrt. config nesting/updating.

"""
FfnnTransducerPredictionNetworkV1 with feedforward layers.
"""

def __init__(self, cfg: FfnnTransducerPredictionNetworkV1Config):
super().__init__(EmbeddingTransducerPredictionNetworkV1Config.from_child(cfg))
Member: Since the first config inherits from the second one, you are able to just do the following. Suggested change:
super().__init__(EmbeddingTransducerPredictionNetworkV1Config.from_child(cfg))
super().__init__(cfg)

Member: EDIT: With composition instead of inheritance, this comment is no longer relevant.

cfg.ffnn_cfg.input_dim = self.output_dim
self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)
Comment on lines +191 to +192 (Member): Leave the configs immutable; always safer wrt. bugs. Suggested change:
cfg.ffnn_cfg.input_dim = self.output_dim
self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)
self.ffnn = FeedForwardBlockV1(
dataclasses.replace(
cfg,
ffnn_cfg=dataclasses.replace(cfg.ffnn_cfg, input_dim=self.output_dim),
)
)

This creates copies of the dataclasses as needed.

Contributor: Or we could not change anything and instead throw an error if a wrong value is configured.

self.output_dim = self.ffnn.output_dim

def forward(
self,
history: torch.Tensor, # [B, H]
) -> torch.Tensor: # [B, 1, P]
embed = super().forward(history)
output = self.ffnn(embed)
return output

def forward_fullsum(
self,
targets: torch.Tensor, # [B, S]
target_lengths: torch.Tensor, # [B]
Contributor: target_lengths seems to be unused in any of the forward calls; is it needed?

) -> Tuple[torch.Tensor, torch.Tensor]: # [B, S + 1, P], [B]
embed, _ = super().forward_fullsum(targets, target_lengths)
output = self.ffnn(embed)
return output, target_lengths

def forward_viterbi(
self,
targets: torch.Tensor, # [B, T]
target_lengths: torch.Tensor, # [B]
) -> Tuple[torch.Tensor, torch.Tensor]: # [B, T, P], [B]
embed, _ = super().forward_viterbi(targets, target_lengths)
output = self.ffnn(embed)
return output, target_lengths
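A minimal sketch (hypothetical, not part of the PR) of the composition-based variant the reviewer suggests: the FFNN prediction network owns an EmbeddingTransducerPredictionNetworkV1 instead of inheriting from it, and copies the nested ffnn_cfg instead of mutating it.

import dataclasses

class ComposedFfnnTransducerPredictionNetwork(nn.Module):  # hypothetical name
    def __init__(self, cfg: FfnnTransducerPredictionNetworkV1Config):
        super().__init__()
        self.embedding_net = EmbeddingTransducerPredictionNetworkV1(cfg)
        # copy the nested config with the corrected input_dim instead of mutating it
        ffnn_cfg = dataclasses.replace(cfg.ffnn_cfg, input_dim=self.embedding_net.output_dim)
        self.ffnn = FeedForwardBlockV1(ffnn_cfg)
        self.output_dim = self.ffnn.output_dim

    def forward_fullsum(self, targets, target_lengths):
        embed, _ = self.embedding_net.forward_fullsum(targets, target_lengths)
        return self.ffnn(embed), target_lengths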
72 changes: 70 additions & 2 deletions i6_models/parts/ffnn.py
@@ -1,8 +1,13 @@
__all__ = ["FeedForwardConfig", "FeedForwardModel"]
Member: What about these old items here? Why are they gone from __all__?

__all__ = [
"FeedForwardLayerV1Config",
"FeedForwardLayerV1",
"FeedForwardBlockV1Config",
"FeedForwardBlockV1",
]

from dataclasses import dataclass
from functools import partial
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union, List

import torch
from torch import nn
@@ -57,3 +62,66 @@ def forward(
tensor = self.activation(tensor) # [B,T,F]
tensor = self.dropout(tensor) # [B,T,F]
return tensor, sequence_mask


@dataclass
class FeedForwardBlockV1Config(ModelConfiguration):
"""
Configuration for the FeedForwardBlockV1 module.

Attributes:
input_dim: Input feature dimension.
layer_sizes: List of hidden layer sizes. The length of this list
determines the number of layers.
dropouts: Dropout probability for each layer.
layer_activations: List of activation function applied after each linear layer.
None represents no activation.
Must have the same length as layer_sizes.
use_layer_norm: Whether to use Layer Normalization.
"""

input_dim: int
layer_sizes: List[int]
Member (@NeoLegends, Sep 15, 2025): I generally prefer the following (suggested change)
layer_sizes: List[int]
layer_sizes: Sequence[int]

in config places like this, as that is also correct for tuples, and I think tuples fit the configs better because they are immutable (but this is debatable).

dropouts: List[float]
Member: Are you customizing these values per layer? If not, consider allowing the simple one-value-fits-all variant as well. Suggested change:
dropouts: List[float]
dropouts: Union[float, Sequence[float]]

layer_activations: List[Optional[Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]]]
use_layer_norm: bool = True

def __post_init__(self):
super().__post_init__()
assert all(0.0 <= dropout <= 1.0 for dropout in self.dropouts), "Dropout values must be probabilities"
assert len(self.layer_sizes) > 0, "layer_sizes must not be empty"
assert len(self.layer_sizes) == len(self.layer_activations)
assert len(self.layer_sizes) == len(self.dropouts)
Comment on lines +89 to +94 (Member): nice!
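If the one-value-fits-all dropout variant suggested above is adopted, a possible normalization in __post_init__ could look like this self-contained toy sketch (stand-in class, not the PR's config):

from dataclasses import dataclass
from typing import Sequence, Union

@dataclass
class _DropoutConfigSketch:  # toy stand-in for FeedForwardBlockV1Config
    layer_sizes: Sequence[int]
    dropouts: Union[float, Sequence[float]]

    def __post_init__(self):
        # broadcast a single dropout value to one entry per layer
        if isinstance(self.dropouts, float):
            self.dropouts = [self.dropouts] * len(self.layer_sizes)
        assert len(self.dropouts) == len(self.layer_sizes)

cfg = _DropoutConfigSketch(layer_sizes=[512, 512], dropouts=0.1)
assert list(cfg.dropouts) == [0.1, 0.1]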



class FeedForwardBlockV1(nn.Module):
"""
A multi-layer feed-forward network block with optional Layer Normalization.
"""

def __init__(self, cfg: FeedForwardBlockV1Config):
super().__init__()
self.cfg = cfg
network_layers: List[nn.Module] = []
prev_size = cfg.input_dim

for i, layer_size in enumerate(cfg.layer_sizes):
if cfg.use_layer_norm:
network_layers.append(nn.LayerNorm(prev_size))
network_layers.append(nn.Linear(prev_size, layer_size))
prev_size = layer_size
if cfg.layer_activations[i] is not None:
network_layers.append(cfg.layer_activations[i])
Member: I have a feeling we should be using ModuleFactorys instead, or at least add support for them, too.

network_layers.append(nn.Dropout(cfg.dropouts[i]))

self.output_dim = cfg.layer_sizes[-1]
self.network = nn.Sequential(*network_layers)

def forward(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the feed-forward block.

:param tensor: Input tensor of shape [B, T, F], where F is input_dim.
:return: Output tensor of shape [B, T, output_dim].
"""
return self.network(tensor)
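A quick usage sketch of the new block (illustrative; shapes follow the docstring, and the import path matches the file shown above):

import torch
from torch import nn
from i6_models.parts.ffnn import FeedForwardBlockV1, FeedForwardBlockV1Config

cfg = FeedForwardBlockV1Config(
    input_dim=80,
    layer_sizes=[512, 512],
    dropouts=[0.1, 0.1],
    layer_activations=[nn.ReLU(), None],
    use_layer_norm=True,
)
block = FeedForwardBlockV1(cfg)
out = block(torch.randn(4, 100, 80))  # [B, T, F] -> [B, T, 512]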