Add transducer components #79
base: main
Changes from all commits: e2c47c0, 745f6d7, 02b4573, 5620c38, 806be26, 69444ff, e73180f, a7297c7, 6ce4f64, 7385d1f, e707f5f
New file (@@ -0,0 +1,2 @@):

```python
from .prediction_network import *
from .joint_network import *
```
New file (@@ -0,0 +1,81 @@):

```python
__all__ = ["TransducerJointNetworkV1Config", "TransducerJointNetworkV1"]

from dataclasses import dataclass
from typing import Any, Dict, Tuple, Union

import torch
from torch import nn

from i6_models.config import ModelConfiguration
from i6_models.parts.ffnn import FeedForwardBlockV1Config, FeedForwardBlockV1


@dataclass
class TransducerJointNetworkV1Config(ModelConfiguration):
    """
    Configuration for the Transducer Joint Network.
    Attributes:
        ffnn_cfg: Configuration for the internal feed-forward network.
    """

    ffnn_cfg: FeedForwardBlockV1Config


class TransducerJointNetworkV1(nn.Module):
    def __init__(
        self,
        cfg: TransducerJointNetworkV1Config,
    ) -> None:
```
Comment on lines +26 to +29 (Member):
Suggested change
```python
        super().__init__()
        self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)
        self.output_dim = self.ffnn.output_dim

    def forward(
        self,
        source_encodings: torch.Tensor,  # [1, T, E]
        target_encodings: torch.Tensor,  # [B, S, P]
```
Comment on lines +36 to +37 (Contributor):
Are source_encodings the output of the acoustic encoder and the target_encodings the output of the prediction network? Maybe we could rename (+ document) this better.
```python
    ) -> torch.Tensor:  # [B, T, S, F]
        """
        Forward pass for recognition.
        """
        combined_encodings = source_encodings.unsqueeze(2) + target_encodings.unsqueeze(1)
```
Member:
See comment wrt. dim indices from the rear.
```python
        output = self.ffnn(combined_encodings)  # [B, T, S, F]

        if not self.training:
            output = torch.log_softmax(output, dim=-1)  # [B, T, S, F]
```
Comment on lines +45 to +46 (Member):
I don't think I'm a big fan of switching between logits and log probs based on whether it's train time or not. I'd rather pass a parameter or leave the log softmax to the …

Contributor:
+1, I think we usually get logits in the train step and apply the appropriate softmax function there.
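A minimal sketch of the alternative being discussed, assuming the joint network always returns logits and the caller applies the normalization where needed (the class and variable names below are illustrative, not part of the PR):

```python
import torch
from torch import nn


class JointNetworkLogitsOnly(nn.Module):
    """Hypothetical variant: always emit logits; the train/search code normalizes if needed."""

    def __init__(self, ffnn: nn.Module):
        super().__init__()
        self.ffnn = ffnn

    def forward(self, source_encodings: torch.Tensor, target_encodings: torch.Tensor) -> torch.Tensor:
        combined = source_encodings.unsqueeze(2) + target_encodings.unsqueeze(1)  # [B, T, S, E]
        return self.ffnn(combined)  # logits, [B, T, S, F]


# The caller decides when log probabilities are actually required:
# logits = joint(source_enc, target_enc)
# log_probs = torch.log_softmax(logits, dim=-1)
```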
```python
        return output

    def forward_viterbi(
        self,
        source_encodings: torch.Tensor,  # [B, T, E]
        source_lengths: torch.Tensor,  # [B]
        target_encodings: torch.Tensor,  # [B, T, P]
        target_lengths: torch.Tensor,  # [B]
    ) -> torch.Tensor:  # [B, T, F]
        """
        Forward pass for Viterbi training.
        """
        combined_encodings = source_encodings + target_encodings
        output = self.ffnn(combined_encodings)  # [B, T, F]
        if not self.training:
            output = torch.log_softmax(output, dim=-1)  # [B, T, F]
```
Comment on lines +61 to +62 (Member):
See above.
```python
        return output, source_lengths, target_lengths

    def forward_fullsum(
        self,
        source_encodings: torch.Tensor,  # [B, T, E]
        source_lengths: torch.Tensor,  # [B]
        target_encodings: torch.Tensor,  # [B, S+1, P]
        target_lengths: torch.Tensor,  # [B]
    ) -> torch.Tensor:  # [B, T, S+1, F]
        """
        Forward pass for fullsum training. Returns output with shape [B, T, S+1, F].
        """

        # additive combination
        combined_encodings = source_encodings.unsqueeze(2) + target_encodings.unsqueeze(1)
```
Member:
Consider unsqueezing from the back (i.e. use negative dim indices like …
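A short sketch of the rear-indexed variant the reviewer appears to have in mind (equivalent for the 3-D inputs documented above; the negative indices are my assumption):

```python
# source_encodings: [B, T, E]   -> [B, T, 1, E]
# target_encodings: [B, S+1, P] -> [B, 1, S+1, P]
combined_encodings = source_encodings.unsqueeze(-2) + target_encodings.unsqueeze(-3)
```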
```python
        # Pass through FFNN
        output = self.ffnn(combined_encodings)  # [B, T, S+1, F]
        return output, source_lengths, target_lengths
```
New file (@@ -0,0 +1,219 @@):

```python
__all__ = [
    "EmbeddingTransducerPredictionNetworkV1Config",
    "EmbeddingTransducerPredictionNetworkV1",
    "FfnnTransducerPredictionNetworkV1Config",
    "FfnnTransducerPredictionNetworkV1",
]

from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn

from i6_models.config import ModelConfiguration
from i6_models.parts.ffnn import FeedForwardBlockV1Config, FeedForwardBlockV1


@dataclass
class EmbeddingTransducerPredictionNetworkV1Config(ModelConfiguration):
    """
    num_outputs: Number of output units (vocabulary size + blank).
    blank_id: Index of the blank token.
    context_history_size: Number of previous output tokens to consider as context.
    embedding_dim: Dimension of the embedding layer.
    reduce_embedding: Whether to use a reduction mechanism for the context embedding.
    num_reduction_heads: Number of reduction heads if reduce_embedding is True.
    """

    num_outputs: int
    blank_id: int
    context_history_size: int
    embedding_dim: int
    reduce_embedding: bool
    num_reduction_heads: Optional[int]

    def __post__init__(self):
```
Contributor:
Suggested change (typo)
```python
        super().__post_init__()
        assert (self.num_reduction_heads is not None) == self.reduce_embedding

    @classmethod
    def from_child(cls, child_instance):
        return cls(
            child_instance.num_outputs,
            child_instance.blank_id,
            child_instance.context_history_size,
            child_instance.embedding_dim,
            child_instance.reduce_embedding,
            child_instance.num_reduction_heads,
        )
```
Comment on lines +41 to +49 (Member):
What does this do, why is it necessary and different from config2 = copy.deepcopy(config1)?
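For context on the question above, a small self-contained illustration of how a from_child-style constructor differs from copy.deepcopy (toy dataclasses; the intent assumed here is narrowing a child config to its parent type):

```python
import copy
from dataclasses import dataclass


@dataclass
class BaseCfg:
    num_outputs: int


@dataclass
class ChildCfg(BaseCfg):
    ffnn_dim: int


child = ChildCfg(num_outputs=80, ffnn_dim=512)

print(type(copy.deepcopy(child)).__name__)       # ChildCfg: deepcopy keeps the subclass and its extra fields
print(type(BaseCfg(child.num_outputs)).__name__)  # BaseCfg: a from_child-style call builds only the parent config
```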
```python
class EmbeddingTransducerPredictionNetworkV1(nn.Module):
```
Member:
Needs docs.
```python
    def __init__(self, cfg: EmbeddingTransducerPredictionNetworkV1Config) -> None:
```
Member:
Suggested change
```python
        super().__init__()
        self.cfg = cfg
        self.blank_id = self.cfg.blank_id
        self.context_history_size = self.cfg.context_history_size
        self.embedding = nn.Embedding(
            num_embeddings=self.cfg.num_outputs,
            embedding_dim=self.cfg.embedding_dim,
            padding_idx=self.blank_id,
        )
        self.output_dim = (
            self.cfg.embedding_dim * self.cfg.context_history_size
            if not self.cfg.reduce_embedding
            else self.cfg.embedding_dim
        )

        self.reduce_embedding = self.cfg.reduce_embedding
        self.num_reduction_heads = self.cfg.num_reduction_heads
        if self.reduce_embedding:
            self.register_buffer(
                "position_vectors",
                torch.randn(
                    self.cfg.context_history_size,
                    self.cfg.num_reduction_heads,
                    self.cfg.embedding_dim,
                ),
            )

    def _reduce_embedding(self, emb: torch.Tensor) -> torch.Tensor:
        """
        Reduces the context embedding using a weighted sum based on position vectors.
        """
        emb_expanded = emb.unsqueeze(3)  # [B, S, H, 1, E]
```
Member:
Consider unsqueezing from the back.
```python
        pos_expanded = self.position_vectors.unsqueeze(0).unsqueeze(0)  # [1, 1, H, K, E]
        alpha = (emb_expanded * pos_expanded).sum(dim=-1, keepdim=True)  # [B, S, H, K, 1]
        weighted = alpha * emb_expanded  # [B, S, H, K, E]
        reduced = weighted.sum(dim=2).sum(dim=2)  # [B, S, E]
```
Member:
Consider indexing dims from the back.

Contributor:
Suggested change
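A hedged sketch of how the two comments above could be addressed with dims indexed from the rear, as a drop-in for the lines of `_reduce_embedding` shown above (same shapes; the negative indices are my assumption):

```python
emb_expanded = emb.unsqueeze(-2)  # [B, S, H, E] -> [B, S, H, 1, E]
alpha = (emb_expanded * pos_expanded).sum(dim=-1, keepdim=True)  # [B, S, H, K, 1]
weighted = alpha * emb_expanded  # [B, S, H, K, E]
reduced = weighted.sum(dim=-3).sum(dim=-2)  # sum over H, then over K -> [B, S, E]
```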
```python
        reduced *= 1.0 / (self.cfg.num_reduction_heads * self.cfg.context_history_size)
        return reduced

    def _forward_embedding(self, history: torch.Tensor) -> torch.Tensor:
        """
        Processes the input history through the embedding layer and optional reduction.
        """
        if len(history.shape) == 2:  # reshape if input shape [B, H]
            history = history.view(*history.shape[:-1], 1, history.shape[-1])  # [B, 1, H]

        embed = self.embedding(history)  # [B, S, H, E]
        if self.reduce_embedding:
            embed = self._reduce_embedding(embed)  # [B, S, E]
        else:
            embed = embed.flatten(start_dim=-2)  # [B, S, H*E]
        return embed

    def forward(
        self,
        history: torch.Tensor,  # [B, H]
    ) -> torch.Tensor:  # [B, 1, P]
        """
        Forward pass for recognition mode.
        """
        embed = self._forward_embedding(history)
        return embed

    def forward_fullsum(
        self,
        targets: torch.Tensor,  # [B, S]
        target_lengths: torch.Tensor,  # [B]
    ) -> Tuple[torch.Tensor, torch.Tensor]:  # [B, S + 1, P], [B]
        """
        Forward pass for fullsum training.
        """
        non_context_padding = torch.full(
            (targets.size(0), self.cfg.context_history_size),
            fill_value=self.blank_id,
            dtype=targets.dtype,
            device=targets.device,
        )  # [B, H]
        extended_targets = torch.cat([non_context_padding, targets], dim=1)  # [B, S+H]
        history = torch.stack(
            [
                extended_targets[:, self.cfg.context_history_size - 1 - i : (-i if i != 0 else None)]
                for i in reversed(range(self.cfg.context_history_size))
            ],
            dim=-1,
        )  # [B, S+1, H]
```
Comment on lines +124 to +137 (Contributor):
ChatGPT suggested this code
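For readers checking the indexing, a small self-contained example of what this history construction produces (toy values, not part of the PR):

```python
import torch

blank_id = 0
context_history_size = 2  # H
targets = torch.tensor([[5, 6, 7]])  # [B=1, S=3]

padding = torch.full((1, context_history_size), blank_id)
extended = torch.cat([padding, targets], dim=1)  # [[0, 0, 5, 6, 7]], length S + H
history = torch.stack(
    [
        extended[:, context_history_size - 1 - i : (-i if i != 0 else None)]
        for i in reversed(range(context_history_size))
    ],
    dim=-1,
)
print(history)
# tensor([[[0, 0], [0, 5], [5, 6], [6, 7]]])  -> [B, S+1, H], most recent label in the last slot
```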
```python
        embed = self._forward_embedding(history)

        return embed, target_lengths

    def forward_viterbi(
        self,
        targets: torch.Tensor,  # [B, T]
        target_lengths: torch.Tensor,  # [B]
    ) -> Tuple[torch.Tensor, torch.Tensor]:  # [B, T, P], [B]
        """
        Forward pass for viterbi training.
        """
        B, T = targets.shape
        history = torch.zeros(
            (B, T, self.cfg.context_history_size),
            dtype=targets.dtype,
            device=targets.device,
        )  # [B, T, H]
        recent_labels = torch.full(
            (B, self.cfg.context_history_size),
            fill_value=self.blank_id,
            dtype=targets.dtype,
            device=targets.device,
        )  # [B, H]

        for t in range(T):
            history[:, t, :] = recent_labels
            current_labels = targets[:, t]
            non_blank_positions = current_labels != self.blank_id
            recent_labels[non_blank_positions, 1:] = recent_labels[non_blank_positions, :-1]
            recent_labels[non_blank_positions, 0] = current_labels[non_blank_positions]
        embed = self._forward_embedding(history)

        return embed, target_lengths


@dataclass
class FfnnTransducerPredictionNetworkV1Config(EmbeddingTransducerPredictionNetworkV1Config):
    """
    Attributes:
        ffnn_cfg: Configuration for FFNN prediction network
    """

    ffnn_cfg: FeedForwardBlockV1Config


class FfnnTransducerPredictionNetworkV1(EmbeddingTransducerPredictionNetworkV1):
```
Member:
I think this class would benefit from using composition instead of inheritance. Make it contain/own an …
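A rough sketch of the composition variant being suggested, building on the classes defined in this diff (the exact wiring is my assumption, not the reviewer's proposal):

```python
class FfnnTransducerPredictionNetworkV1(nn.Module):
    """Hypothetical composition-based variant: owns an embedding network instead of inheriting from it."""

    def __init__(self, cfg: FfnnTransducerPredictionNetworkV1Config) -> None:
        super().__init__()
        self.embedding_network = EmbeddingTransducerPredictionNetworkV1(
            EmbeddingTransducerPredictionNetworkV1Config.from_child(cfg)
        )
        self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)
        self.output_dim = self.ffnn.output_dim

    def forward(self, history: torch.Tensor) -> torch.Tensor:
        return self.ffnn(self.embedding_network(history))
```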
| """ | ||||||||||||||||||
| FfnnTransducerPredictionNetworkV1 with feedforward layers. | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
||||||||||||||||||
| def __init__(self, cfg: FfnnTransducerPredictionNetworkV1Config): | ||||||||||||||||||
| super().__init__(EmbeddingTransducerPredictionNetworkV1Config.from_child(cfg)) | ||||||||||||||||||
|
Member:
Since the first config inherits from the second one, you are able to just:
Suggested change

Member:
EDIT: With composition instead of inheritance, this comment is no longer relevant.
```python
        cfg.ffnn_cfg.input_dim = self.output_dim
        self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)
```
Comment on lines +191 to +192 (Member):
Leave the configs immutable. Always safer wrt. bugs.
Suggested change
This creates copies of the dataclasses as needed.

Contributor:
or we could not change anything and throw an error if a value is configured that is wrong..
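Hedged sketches of the two alternatives raised here, written as a drop-in for the two lines above and assuming FeedForwardBlockV1Config behaves like a plain dataclass so dataclasses.replace applies:

```python
from dataclasses import replace

# Option 1: keep cfg untouched and build an adjusted copy of the nested config.
ffnn_cfg = replace(cfg.ffnn_cfg, input_dim=self.output_dim)
self.ffnn = FeedForwardBlockV1(ffnn_cfg)

# Option 2: change nothing and fail loudly on an inconsistent configuration.
assert cfg.ffnn_cfg.input_dim == self.output_dim, "ffnn_cfg.input_dim must match the prediction embedding dim"
```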
```python
        self.output_dim = self.ffnn.output_dim

    def forward(
        self,
        history: torch.Tensor,  # [B, H]
    ) -> torch.Tensor:  # [B, 1, P]
        embed = super().forward(history)
        output = self.ffnn(embed)
        return output

    def forward_fullsum(
        self,
        targets: torch.Tensor,  # [B, S]
        target_lengths: torch.Tensor,  # [B]
    ) -> Tuple[torch.Tensor, torch.Tensor]:  # [B, S + 1, P], [B]
        embed, _ = super().forward_fullsum(targets, target_lengths)
        output = self.ffnn(embed)
        return output, target_lengths

    def forward_viterbi(
        self,
        targets: torch.Tensor,  # [B, T]
        target_lengths: torch.Tensor,  # [B]
    ) -> Tuple[torch.Tensor, torch.Tensor]:  # [B, T, P], [B]
        embed, _ = super().forward_viterbi(targets, target_lengths)
        output = self.ffnn(embed)
        return output, target_lengths
```
Changed file, first hunk (@@ -1,8 +1,13 @@):

```diff
-__all__ = ["FeedForwardConfig", "FeedForwardModel"]
```
Member:
What about these old items here? Why are they gone from …
```diff
+__all__ = [
+    "FeedForwardLayerV1Config",
+    "FeedForwardLayerV1",
+    "FeedForwardBlockV1Config",
+    "FeedForwardBlockV1",
+]

 from dataclasses import dataclass
 from functools import partial
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union, List

 import torch
 from torch import nn
```

Second hunk (@@ -57,3 +62,66 @@ def forward(): existing context lines, followed by the new block:

```python
        tensor = self.activation(tensor)  # [B,T,F]
        tensor = self.dropout(tensor)  # [B,T,F]
        return tensor, sequence_mask
```
```python
@dataclass
class FeedForwardBlockV1Config(ModelConfiguration):
    """
    Configuration for the FeedForwardBlockV1 module.

    Attributes:
        input_dim: Input feature dimension.
        layer_sizes: List of hidden layer sizes. The length of this list
            determines the number of layers.
        dropouts: Dropout probability for each layer.
        layer_activations: List of activation functions applied after each linear layer.
            None represents no activation.
            Must have the same length as layer_sizes.
        use_layer_norm: Whether to use Layer Normalization.
    """

    input_dim: int
    layer_sizes: List[int]
```
Member:
I generally prefer
Suggested change
in config places like this as that also is correct for tuples, and I think tuples fit the configs better because they are immutable (but this is debatable).
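This presumably refers to typing the config fields with Sequence rather than List; a hedged sketch of that reading (the class name below is hypothetical):

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class FeedForwardBlockConfigSketch:
    input_dim: int
    layer_sizes: Sequence[int]  # accepts lists as well as immutable tuples
    dropouts: Sequence[float]
```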
```python
    dropouts: List[float]
```
Member:
Are you customizing these values per layer? If not, consider allowing the simple one-value-fits-all variant as well like:
Suggested change
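A hedged sketch of the one-value-fits-all variant (hypothetical class, broadcasting a scalar dropout to one value per layer in __post_init__):

```python
from dataclasses import dataclass
from typing import Sequence, Union


@dataclass
class DropoutConfigSketch:
    layer_sizes: Sequence[int]
    dropouts: Union[float, Sequence[float]]

    def __post_init__(self):
        if isinstance(self.dropouts, float):
            # broadcast the single value to one dropout probability per layer
            self.dropouts = [self.dropouts] * len(self.layer_sizes)
```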
```python
    layer_activations: List[Optional[Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]]]
    use_layer_norm: bool = True

    def __post_init__(self):
        super().__post_init__()
        assert all(0.0 <= dropout <= 1.0 for dropout in self.dropouts), "Dropout values must be probabilities"
        assert len(self.layer_sizes) > 0, "layer_sizes must not be empty"
        assert len(self.layer_sizes) == len(self.layer_activations)
        assert len(self.layer_sizes) == len(self.dropouts)
```
Comment on lines +89 to +94 (Member):
nice!
```python
class FeedForwardBlockV1(nn.Module):
    """
    A multi-layer feed-forward network block with optional Layer Normalization.
    """

    def __init__(self, cfg: FeedForwardBlockV1Config):
        super().__init__()
        self.cfg = cfg
        network_layers: List[nn.Module] = []
        prev_size = cfg.input_dim

        for i, layer_size in enumerate(cfg.layer_sizes):
            if cfg.use_layer_norm:
                network_layers.append(nn.LayerNorm(prev_size))
            network_layers.append(nn.Linear(prev_size, layer_size))
            prev_size = layer_size
            if cfg.layer_activations[i] is not None:
                network_layers.append(cfg.layer_activations[i])
```
Member:
I have a feeling we should be using …
```python
            network_layers.append(nn.Dropout(cfg.dropouts[i]))

        self.output_dim = cfg.layer_sizes[-1]
        self.network = nn.Sequential(*network_layers)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the feed-forward block.

        :param tensor: Input tensor of shape [B, T, F], where F is input_dim.
        :return: Output tensor of shape [B, T, output_dim].
        """
        return self.network(tensor)
```
Review comment:
Needs docs.
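Until docs land, a hedged usage sketch of the new block as defined in this diff (the values are illustrative only):

```python
import torch
from torch import nn

cfg = FeedForwardBlockV1Config(
    input_dim=512,
    layer_sizes=[1024, 1024],
    dropouts=[0.1, 0.1],
    layer_activations=[nn.ReLU(), None],
    use_layer_norm=True,
)
block = FeedForwardBlockV1(cfg)
features = torch.randn(4, 100, 512)  # [B, T, F]
out = block(features)  # [B, T, 1024]
```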