Add transducer components #79
Conversation
Perhaps merge master and the CI issues will go away.
Nice! Left some comments.
@@ -1,8 +1,13 @@
__all__ = ["FeedForwardConfig", "FeedForwardModel"]
What about these old items here? Why are they gone from __all__?
input_dim: int
layer_sizes: List[int]
dropouts: List[float]
Are you customizing these values per layer? If not, consider allowing the simple one-value-fits-all variant as well, like:

- dropouts: List[float]
+ dropouts: Union[float, Sequence[float]]
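A minimal sketch of how such a one-value variant could be normalized in __post_init__ (the config class here is a simplified stand-in, not the PR's actual one):

from dataclasses import dataclass
from typing import Sequence, Union

@dataclass
class FeedForwardConfig:
    layer_sizes: Sequence[int]
    dropouts: Union[float, Sequence[float]]

    def __post_init__(self):
        # Broadcast a single dropout probability to all layers.
        if isinstance(self.dropouts, float):
            self.dropouts = [self.dropouts] * len(self.layer_sizes)
        assert len(self.dropouts) == len(self.layer_sizes)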
| """ | ||
|
|
||
| input_dim: int | ||
| layer_sizes: List[int] |
I generally prefer

- layer_sizes: List[int]
+ layer_sizes: Sequence[int]

in config places like this, as that is also correct for tuples, and I think tuples fit configs better because they are immutable (but this is debatable).
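For illustration, a Sequence[int] annotation covers both (BlockConfig is a hypothetical name):

from dataclasses import dataclass
from typing import Sequence

@dataclass
class BlockConfig:
    layer_sizes: Sequence[int]

BlockConfig(layer_sizes=[512, 512])  # list works
BlockConfig(layer_sizes=(512, 512))  # tuple works too and cannot be mutated accidentally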
def __post_init__(self):
    super().__post_init__()
    assert all(0.0 <= dropout <= 1.0 for dropout in self.dropouts), "Dropout values must be probabilities"
    assert len(self.layer_sizes) > 0, "layer_sizes must not be empty"
    assert len(self.layer_sizes) == len(self.layer_activations)
    assert len(self.layer_sizes) == len(self.dropouts)
nice!
network_layers.append(nn.Linear(prev_size, layer_size))
prev_size = layer_size
if cfg.layer_activations[i] is not None:
    network_layers.append(cfg.layer_activations[i])
I have a feeling we should be using ModuleFactorys instead, or at least add support for them, too.
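A sketch of what that could look like, assuming a factory is just a zero-argument callable returning a fresh nn.Module (the exact ModuleFactory API in this repo may differ):

from typing import Callable, Optional, Sequence
import torch.nn as nn

ActivationFactory = Callable[[], nn.Module]  # e.g. nn.ReLU or lambda: nn.LeakyReLU(0.1)

def build_network(
    input_dim: int,
    layer_sizes: Sequence[int],
    activations: Sequence[Optional[ActivationFactory]],
) -> nn.Sequential:
    layers = []
    prev_size = input_dim
    for layer_size, activation in zip(layer_sizes, activations):
        layers.append(nn.Linear(prev_size, layer_size))
        prev_size = layer_size
        if activation is not None:
            layers.append(activation())  # build a fresh module instance per layer
    return nn.Sequential(*layers)

Storing factories rather than module instances in the config also avoids accidentally sharing one module instance across layers.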
| """ | ||
| Reduces the context embedding using a weighted sum based on position vectors. | ||
| """ | ||
| emb_expanded = emb.unsqueeze(3) # [B, S, H, 1, E] |
Consider unsqueezing from the back.
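That is, assuming emb is [B, S, H, E]:

emb_expanded = emb.unsqueeze(-2)  # [B, S, H, 1, E], same result, robust to extra leading dims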
pos_expanded = self.position_vectors.unsqueeze(0).unsqueeze(0)  # [1, 1, H, K, E]
alpha = (emb_expanded * pos_expanded).sum(dim=-1, keepdim=True)  # [B, S, H, K, 1]
weighted = alpha * emb_expanded  # [B, S, H, K, E]
reduced = weighted.sum(dim=2).sum(dim=2)  # [B, S, E]
Consider indexing dims from the back.
dim can be a tuple of ints, so we could do it in one step:
https://docs.pytorch.org/docs/stable/generated/torch.sum.html

- reduced = weighted.sum(dim=2).sum(dim=2)  # [B, S, E]
+ reduced = weighted.sum(dim=(-3, -2))  # [B, S, E]

(weighted is [B, S, H, K, E], so the H and K dims are -3 and -2 counted from the back)
| """ | ||
|
|
||
| def __init__(self, cfg: FfnnTransducerPredictionNetworkV1Config): | ||
| super().__init__(EmbeddingTransducerPredictionNetworkV1Config.from_child(cfg)) |
Since the first config inherits from the second one, you are able to just:
- super().__init__(EmbeddingTransducerPredictionNetworkV1Config.from_child(cfg))
+ super().__init__(cfg)
EDIT: With composition instead of inheritance, this comment is no longer relevant.
cfg.ffnn_cfg.input_dim = self.output_dim
self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)
Leave the configs immutable. Always safer wrt. bugs.

- cfg.ffnn_cfg.input_dim = self.output_dim
- self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)
+ cfg = dataclasses.replace(cfg, ffnn_cfg=dataclasses.replace(cfg.ffnn_cfg, input_dim=self.output_dim))
+ self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)

This creates copies of the dataclasses as needed instead of mutating the caller's config in place.
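To make accidental mutation fail loudly, the config dataclasses could additionally be declared frozen (a sketch; the real config has more fields):

import dataclasses

@dataclasses.dataclass(frozen=True)
class FeedForwardBlockV1Config:
    input_dim: int

cfg = FeedForwardBlockV1Config(input_dim=256)
cfg = dataclasses.replace(cfg, input_dim=512)  # fine: builds a modified copy
# cfg.input_dim = 512  # would raise dataclasses.FrozenInstanceError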
Or we could not change anything and instead throw an error if a wrong value is configured.
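A sketch of that validation variant (the error message is illustrative):

def __init__(self, cfg: FfnnTransducerPredictionNetworkV1Config):
    ...
    # Fail fast instead of silently patching the config.
    if cfg.ffnn_cfg.input_dim != self.output_dim:
        raise ValueError(
            f"ffnn_cfg.input_dim ({cfg.ffnn_cfg.input_dim}) must equal "
            f"the prediction network output dim ({self.output_dim})"
        )
    self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)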
ffnn_cfg: FeedForwardBlockV1Config

class FfnnTransducerPredictionNetworkV1(EmbeddingTransducerPredictionNetworkV1):
I think this class would benefit from using composition instead of inheritance. Make it contain/own an EmbeddingTransducerPredictionNetworkV1 instead of inheriting from one. That resolves all your issues wrt. config nesting/updating.
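A minimal sketch of the composition variant (class names mirror the PR; the nested embedding_cfg field is an assumption):

import torch
import torch.nn as nn

class FfnnTransducerPredictionNetworkV1(nn.Module):
    def __init__(self, cfg: FfnnTransducerPredictionNetworkV1Config):
        super().__init__()
        # Own an embedding network instead of being one.
        self.embedding_network = EmbeddingTransducerPredictionNetworkV1(cfg.embedding_cfg)
        self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)

    def forward(self, history: torch.Tensor) -> torch.Tensor:
        return self.ffnn(self.embedding_network(history))

Each sub-config then stays self-contained, so nothing has to be copied or patched across inheritance levels.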
if not self.training:
    output = torch.log_softmax(output, dim=-1)  # [B, T, S, F]
+1, I think we usually get logits in the train step and apply the appropriate softmax function there
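i.e. roughly this pattern (train_step here is a hypothetical sketch, not the actual training interface):

import torch

def train_step(model, audio_features, targets):
    # The model returns raw logits in all modes ...
    logits = model(audio_features, targets)  # [B, T, S, F]
    # ... and the train step applies whatever normalization the loss needs.
    log_probs = torch.log_softmax(logits, dim=-1)
    return log_probs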
source_encodings: torch.Tensor,  # [1, T, E]
target_encodings: torch.Tensor,  # [B, S, P]
Are source_encodings the output of the acoustic encoder, and target_encodings the output of the prediction network? Maybe we could rename (+ document) these better.
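For example (names are just one possibility, not from the PR):

def forward(
    self,
    encoder_out: torch.Tensor,  # [1, T, E], output of the acoustic encoder
    prediction_out: torch.Tensor,  # [B, S, P], output of the prediction network
) -> torch.Tensor:
    ...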
Processes the input history through the embedding layer and optional reduction.
"""
if len(history.shape) == 2:  # reshape if input shape [B, H]
    history = history.view(*history.shape[:-1], 1, history.shape[-1])  # [B, 1, H]
*history.shape[:-1] reads odd... that should be the same as history.shape[0], since we have len(history.shape) == 2... but talk to @NeoLegends about making this work with more batch dims.
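One batch-dim-agnostic sketch: insert the singleton axis from the back, which works for [B, H] as well as for extra leading batch dims:

history = history.unsqueeze(-2)  # [..., 1, H]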
def forward_fullsum(
    self,
    targets: torch.Tensor,  # [B, S]
    target_lengths: torch.Tensor,  # [B]
target_lengths seems to be unused in any of the forward calls... is it needed?
reduce_embedding: bool
num_reduction_heads: Optional[int]

def __post__init__(self):
- def __post__init__(self):
+ def __post_init__(self):

typo
non_context_padding = torch.full(
    (targets.size(0), self.cfg.context_history_size),
    fill_value=self.blank_id,
    dtype=targets.dtype,
    device=targets.device,
)  # [B, H]
extended_targets = torch.cat([non_context_padding, targets], dim=1)  # [B, S+H]
history = torch.stack(
    [
        extended_targets[:, self.cfg.context_history_size - 1 - i : (-i if i != 0 else None)]
        for i in reversed(range(self.cfg.context_history_size))
    ],
    dim=-1,
)  # [B, S+1, H]
ChatGPT suggested this code:

import torch.nn.functional as F

B, S = targets.shape
H = self.cfg.context_history_size
# Pad left with H blanks: [B, S+H]
extended = F.pad(targets, (H, 0), value=self.blank_id)
# Unfold over the sequence dim to get [B, S+1, H]
# (PyTorch: unfold(size=H, step=1) "slides" a length-H window)
history = extended.unfold(dimension=1, size=H, step=1)  # [B, S+1, H]
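A quick standalone sanity check that both variants agree (blank_id and the history size are made-up values):

import torch
import torch.nn.functional as F

targets = torch.tensor([[3, 1, 4, 1, 5]])  # [B=1, S=5]
blank_id, H = 0, 2

extended = F.pad(targets, (H, 0), value=blank_id)  # [B, S+H]
history_unfold = extended.unfold(dimension=1, size=H, step=1)  # [B, S+1, H]

history_stack = torch.stack(
    [extended[:, H - 1 - i : (-i if i != 0 else None)] for i in reversed(range(H))],
    dim=-1,
)  # [B, S+1, H]

assert torch.equal(history_unfold, history_stack)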
I add the transducer prediction and joint networks here. The advantage is that the interface supports three modes: recognition, fixed-path training (Viterbi), and full-sum training (standard RNN-T). Other NN structures such as an LSTM prediction network (I tested training but not recognition, so it is not included here) can also be added easily. I also support embedding reduction as introduced in this paper, which gave a slight improvement in my test. Many lines of code originate from Simon's setup.