@@ -2,10 +2,10 @@

 __all__ = ["ConformerMHSAV1", "ConformerMHSAV1Config"]
 from dataclasses import dataclass
-from typing import Optional
 import torch

 from i6_models.config import ModelConfiguration
+from i6_models.util import compat


 @dataclass
@@ -43,17 +43,19 @@ def __init__(self, cfg: ConformerMHSAV1Config):
         )
         self.dropout = cfg.dropout

-    def forward(self, input_tensor: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
         """
         Apply layer norm and multi-head self attention and dropout
-        :param Optional[torch.Tensor] key_padding_mask: could be a binary or float mask of shape (B, T)
+
+        :param input_tensor: Input to the self attention of shape (B, T, F)
+        :param sequence_mask: bool mask of shape (B, T), True signals within sequence, False outside, will be inverted to match the torch.nn.MultiheadAttention module
             which will be applied/added to dot product, used to mask padded key positions out
         """
-
+        inv_sequence_mask = compat.logical_not(sequence_mask)
         output_tensor = self.layernorm(input_tensor)  # [B,T,F]

         output_tensor, _ = self.mhsa(
-            output_tensor, output_tensor, output_tensor, key_padding_mask=key_padding_mask, need_weights=False
+            output_tensor, output_tensor, output_tensor, key_padding_mask=inv_sequence_mask, need_weights=False
         )  # [B,T,F]
         output_tensor = torch.nn.functional.dropout(output_tensor, p=self.dropout, training=self.training)  # [B,T,F]

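Note on the change: forward() now takes a sequence_mask where True marks frames inside the sequence, while torch.nn.MultiheadAttention expects a key_padding_mask where True marks key positions to ignore, so the diff inverts the mask via compat.logical_not before the attention call. Below is a minimal sketch of that convention using plain torch.logical_not (the compat helper is assumed to be a version-compatible equivalent); the names seq_lens and max_len and the small MultiheadAttention configuration are illustrative only, not taken from the repository.

    # Sketch only: shows the sequence_mask -> key_padding_mask inversion, not the module itself.
    import torch

    seq_lens = torch.tensor([4, 2, 3])   # valid length per sequence in the batch (illustrative)
    max_len = int(seq_lens.max())        # padded time dimension T

    # sequence_mask: True inside the sequence, False at padded positions, shape [B, T]
    sequence_mask = torch.arange(max_len)[None, :] < seq_lens[:, None]

    # torch.nn.MultiheadAttention uses the opposite convention for key_padding_mask:
    # True marks key positions to be ignored, hence the inversion in the new forward().
    inv_sequence_mask = torch.logical_not(sequence_mask)

    mhsa = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
    x = torch.randn(3, max_len, 8)       # [B, T, F]
    out, _ = mhsa(x, x, x, key_padding_mask=inv_sequence_mask, need_weights=False)
    print(out.shape)                     # torch.Size([3, 4, 8])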