
Conversation

@Stefanwuu (Contributor):

I added the transducer prediction network and joint network here. The advantage is that the interface supports three modes: recognition, fixed-path training (Viterbi), and fullsum training (standard RNN-T). Other NN structures, such as an LSTM prediction network (I tested training but not recognition, so it is not included here), can also be added easily. I also added support for embedding reduction as introduced in this paper, which gave a slight improvement in my tests. Many lines of code originate from Simon's setup.
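
For illustration, a minimal sketch of what such a three-mode interface could look like (names below are hypothetical, not necessarily the PR's actual API):

    import torch
    from torch import nn

    class TransducerPredictionNetwork(nn.Module):
        """Sketch: one module, three entry points for the three modes."""

        def forward_fullsum(self, targets: torch.Tensor) -> torch.Tensor:
            """Standard RNN-T training: score all label contexts so the loss
            can marginalize over every alignment path."""
            raise NotImplementedError

        def forward_viterbi(self, alignment: torch.Tensor) -> torch.Tensor:
            """Fixed-path training: score only the given (Viterbi) alignment."""
            raise NotImplementedError

        def forward_step(self, history: torch.Tensor) -> torch.Tensor:
            """Recognition: advance the predictor one label at a time during search."""
            raise NotImplementedError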

@NeoLegends (Member):

Perhaps merge master and the CI issues will go away.

@NeoLegends (Member) left a comment:

Nice! Left some comments.

@@ -1,8 +1,13 @@
__all__ = ["FeedForwardConfig", "FeedForwardModel"]
Member:

What about these old items here? Why are they gone from __all__?


input_dim: int
layer_sizes: List[int]
dropouts: List[float]
Member:

Are you customizing these values per layer? If not, consider also allowing the simple one-value-fits-all variant, like:

Suggested change
dropouts: List[float]
dropouts: Union[float, Sequence[float]]
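
The scalar variant could then be normalized once in __post_init__, e.g. (a sketch, assuming the fields shown above):

    def __post_init__(self):
        super().__post_init__()
        # Broadcast a single dropout value to one entry per layer.
        if isinstance(self.dropouts, float):
            self.dropouts = [self.dropouts] * len(self.layer_sizes)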

"""

input_dim: int
layer_sizes: List[int]
@NeoLegends (Member), Sep 15, 2025:

I generally prefer

Suggested change
layer_sizes: List[int]
layer_sizes: Sequence[int]

in config places like this, as that is also correct for tuples, and I think tuples fit the configs better because they are immutable (but this is debatable).
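
With Sequence[int], an immutable tuple then type-checks just as well, e.g. (field values are illustrative only):

    cfg = FeedForwardConfig(
        input_dim=80,
        layer_sizes=(512, 512, 512),  # tuples are Sequences but cannot be mutated later
        dropouts=(0.1, 0.1, 0.1),
        layer_activations=(nn.ReLU(), nn.ReLU(), None),
    )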

Comment on lines +89 to +94
def __post_init__(self):
super().__post_init__()
assert all(0.0 <= dropout <= 1.0 for dropout in self.dropouts), "Dropout values must be probabilities"
assert len(self.layer_sizes) > 0, "layer_sizes must not be empty"
assert len(self.layer_sizes) == len(self.layer_activations)
assert len(self.layer_sizes) == len(self.dropouts)
Member:

nice!

network_layers.append(nn.Linear(prev_size, layer_size))
prev_size = layer_size
if cfg.layer_activations[i] is not None:
network_layers.append(cfg.layer_activations[i])
Member:

I have a feeling we should be using ModuleFactorys instead, or at least add support for them, too.
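
A rough sketch of the factory idea (hand-rolled here; the actual ModuleFactory interface in i6_models may differ):

    from dataclasses import dataclass
    from typing import Callable

    from torch import nn

    @dataclass
    class ModuleFactory:
        """Store a constructor instead of a module instance, so the config stays
        plain data and every layer gets its own fresh activation module."""
        construct: Callable[[], nn.Module]

    # config side: layer_activations: Sequence[Optional[ModuleFactory]]
    # model side:
    #     if cfg.layer_activations[i] is not None:
    #         network_layers.append(cfg.layer_activations[i].construct())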

"""
Reduces the context embedding using a weighted sum based on position vectors.
"""
emb_expanded = emb.unsqueeze(3) # [B, S, H, 1, E]
Member:

Consider unsqueezing from the back.
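
E.g. a negative axis keeps the line independent of the number of leading dims:

    emb_expanded = emb.unsqueeze(-2)  # [B, S, H, E] -> [B, S, H, 1, E]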

pos_expanded = self.position_vectors.unsqueeze(0).unsqueeze(0) # [1, 1, H, K, E]
alpha = (emb_expanded * pos_expanded).sum(dim=-1, keepdim=True) # [B, S, H, K, 1]
weighted = alpha * emb_expanded # [B, S, H, K, E]
reduced = weighted.sum(dim=2).sum(dim=2) # [B, S, E]
Member:

Consider indexing dims from the back.

Contributor:

dim can be a tuple of ints, so we could do it in one step
https://docs.pytorch.org/docs/stable/generated/torch.sum.html

Suggested change
reduced = weighted.sum(dim=2).sum(dim=2) # [B, S, E]
reduced = weighted.sum(dim=(-3, -2)) # [B, S, E] (sum over H and K)

"""

def __init__(self, cfg: FfnnTransducerPredictionNetworkV1Config):
super().__init__(EmbeddingTransducerPredictionNetworkV1Config.from_child(cfg))
Member:

Since the first config inherits from the second one, you are able to just:

Suggested change
super().__init__(EmbeddingTransducerPredictionNetworkV1Config.from_child(cfg))
super().__init__(cfg)

Member:

EDIT: With composition instead of inheritance, this comment is no longer relevant.

Comment on lines +191 to +192
cfg.ffnn_cfg.input_dim = self.output_dim
self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)
Member:

Leave the configs immutable. Always safer wrt. bugs.

Suggested change
cfg.ffnn_cfg.input_dim = self.output_dim
self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)
self.ffnn = FeedForwardBlockV1(dataclasses.replace(cfg.ffnn_cfg, input_dim=self.output_dim))

This creates a copy of the config with the override applied, instead of mutating the shared instance.
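
For reference, dataclasses.replace returns a modified copy and leaves the original untouched; marking the configs frozen would make accidental mutation fail loudly (standalone sketch):

    from dataclasses import dataclass, replace

    @dataclass(frozen=True)
    class FfnnCfg:
        input_dim: int = 80

    cfg = FfnnCfg()
    new_cfg = replace(cfg, input_dim=512)  # copy with the override applied
    # cfg.input_dim = 512  # would raise dataclasses.FrozenInstanceError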

Contributor:

Or we could not change anything, and instead throw an error if a mismatching value is configured.

ffnn_cfg: FeedForwardBlockV1Config


class FfnnTransducerPredictionNetworkV1(EmbeddingTransducerPredictionNetworkV1):
Member:

I think this class would benefit from using composition instead of inheritance. Make it contain/own an EmbeddingTransducerPredictionNetworkV1 instead of inheriting from one. That resolves all your issues wrt. config nesting/updating.
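
A rough sketch of the composition variant (the config field names are hypothetical):

    class FfnnTransducerPredictionNetworkV1(nn.Module):
        def __init__(self, cfg: FfnnTransducerPredictionNetworkV1Config):
            super().__init__()
            # Own the embedding network instead of inheriting from it; no config
            # up-casting (from_child) or nested config mutation is needed.
            self.embedding_net = EmbeddingTransducerPredictionNetworkV1(cfg.embedding_cfg)
            self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)

        def forward(self, history: torch.Tensor) -> torch.Tensor:
            return self.ffnn(self.embedding_net(history))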

Comment on lines +45 to +46
if not self.training:
output = torch.log_softmax(output, dim=-1) # [B, T, S, F]
Contributor:

+1, I think we usually get logits in the train step and apply the appropriate softmax function there
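
Roughly, a sketch of that convention (the exact loss call depends on which RNN-T loss implementation is used):

    # in the model: always return raw logits
    logits = self.joint_network(source_encodings, target_encodings)  # [B, T, S, F]

    # in the train step: normalize only where the loss expects it
    log_probs = torch.log_softmax(logits, dim=-1)
    loss = rnnt_loss(log_probs, targets, input_lengths, target_lengths)  # hypothetical signature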

Comment on lines +36 to +37
source_encodings: torch.Tensor, # [1, T, E]
target_encodings: torch.Tensor, # [B, S, P]
Contributor:

Are source_encodings the output of the acoustic encoder, and target_encodings the output of the prediction network? Maybe we could rename (and document) these better.
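
E.g., if that reading is right, something along these lines:

    source_encodings: torch.Tensor,  # acoustic encoder output, [1, T, E]
    target_encodings: torch.Tensor,  # prediction network output, [B, S, P]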


Processes the input history through the embedding layer and optional reduction.
"""
if len(history.shape) == 2: # reshape if input shape [B, H]
history = history.view(*history.shape[:-1], 1, history.shape[-1]) # [B, 1, H]
Contributor:

*history.shape[:-1] reads oddly... that should be the same as history.shape[0], since we have len(history.shape) == 2. But talk to @NeoLegends about making this work with more batch dims.
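
A sketch of a cleaner spelling; unsqueeze(-2) itself is shape-agnostic, though the dim check would still need generalizing for extra batch dims:

    if history.dim() == 2:  # [B, H] -> [B, 1, H]
        history = history.unsqueeze(-2)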


def forward_fullsum(
self,
targets: torch.Tensor, # [B, S]
target_lengths: torch.Tensor, # [B]
Contributor:

target_lengths seems to be unused in all of the forward calls. Is it needed?

reduce_embedding: bool
num_reduction_heads: Optional[int]

def __post__init__(self):
Contributor:

Suggested change
def __post__init__(self):
def __post_init__(self):

typo

Comment on lines +124 to +137
non_context_padding = torch.full(
(targets.size(0), self.cfg.context_history_size),
fill_value=self.blank_id,
dtype=targets.dtype,
device=targets.device,
) # [B, H]
extended_targets = torch.cat([non_context_padding, targets], dim=1) # [B, S+H]
history = torch.stack(
[
extended_targets[:, self.cfg.context_history_size - 1 - i : (-i if i != 0 else None)]
for i in reversed(range(self.cfg.context_history_size))
],
dim=-1,
) # [B, S+1, H]
Contributor:

ChatGPT suggested this code:

        import torch.nn.functional as F  # needed for F.pad

        B, S = targets.shape
        H = self.cfg.context_history_size

        # Pad left with H blanks: [B, S+H]
        extended = F.pad(targets, (H, 0), value=self.blank_id)

        # Unfold over sequence dim to get [B, S+1, H]
        # (PyTorch: unfold(size=H, step=1) "slides" a length-H window)
        history = extended.unfold(dimension=1, size=H, step=1)  # [B, S+1, H]
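
A quick standalone shape check of the unfold variant (with made-up values):

    import torch
    import torch.nn.functional as F

    targets = torch.randint(1, 10, (2, 5))  # [B=2, S=5]
    H, blank_id = 3, 0
    extended = F.pad(targets, (H, 0), value=blank_id)       # [B, S+H] = [2, 8]
    history = extended.unfold(dimension=1, size=H, step=1)  # [B, S+1, H]
    assert history.shape == (2, 6, 3)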
