Skip to content

Commit ab4b708

Browse files
authored
update docstrings (#20)
1 parent 2e84826 commit ab4b708

File tree

3 files changed

+25
-13
lines changed

3 files changed

+25
-13
lines changed

i6_models/parts/conformer/convolution.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,20 @@
1212

1313
@dataclass
1414
class ConformerConvolutionV1Config(ModelConfiguration):
15+
"""
16+
Attributes:
17+
channels: number of channels for conv layers
18+
kernel_size: kernel size of conv layers
19+
dropout: dropout probability
20+
activation: activation function applied after normalization
21+
norm: normalization layer with input of shape [N,C,T]
22+
"""
23+
1524
channels: int
16-
"""number of channels for conv layers"""
1725
kernel_size: int
18-
"""kernel size of conv layers"""
1926
dropout: float
20-
"""dropout probability"""
2127
activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]
22-
"""activation function applied after norm"""
2328
norm: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]
24-
"""normalization layer with input of shape [N,C,T]"""
2529

2630
def check_valid(self):
2731
assert self.kernel_size % 2 == 1, "ConformerConvolutionV1 only supports odd kernel sizes"

i6_models/parts/conformer/feedforward.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,18 @@
1313

1414
@dataclass
1515
class ConformerPositionwiseFeedForwardV1Config(ModelConfiguration):
16+
"""
17+
Attributes:
18+
input_dim: input dimension
19+
hidden_dim: hidden dimension (normally set to 4*input_dim as suggested by the paper)
20+
dropout: dropout probability
21+
activation: activation function
22+
"""
23+
1624
input_dim: int
17-
"""input dimension"""
1825
hidden_dim: int
19-
"""hidden dimension (normally set to 4*input_dim as suggested by the paper)"""
2026
dropout: float
21-
"""dropout probability"""
2227
activation: Callable[[torch.Tensor], torch.Tensor] = nn.functional.silu
23-
"""activation function"""
2428

2529

2630
class ConformerPositionwiseFeedForwardV1(nn.Module):

i6_models/parts/conformer/mhsa.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,18 @@
1010

1111
@dataclass
1212
class ConformerMHSAV1Config(ModelConfiguration):
13+
"""
14+
Attributes:
15+
input_dim: input dim and total dimension for query/key and value projections, should be divisible by `num_att_heads`
16+
num_att_heads: number of attention heads
17+
att_weights_dropout: attention weights dropout
18+
dropout: multi-headed self attention output dropout
19+
"""
20+
1321
input_dim: int
14-
"""input dim and total dimension for query/key and value projections, should be divisible by `num_att_heads`"""
1522
num_att_heads: int
16-
"""number of attention heads"""
1723
att_weights_dropout: float
18-
"""attention weights dropout"""
1924
dropout: float
20-
"""multi-headed self attention output dropout"""
2125

2226
def __post_init__(self) -> None:
2327
super().__post_init__()

0 commit comments

Comments (0)