
Commit 7e345a1

Atticus1806, michelwi, and albertz authored
Fix MHSA masking (#22)
* fix masking
* bool
* None and tests
* cast to float
* Update i6_models/parts/conformer/mhsa.py
  Co-authored-by: michelwi <[email protected]>
* Update tests/test_conformer.py
  Co-authored-by: michelwi <[email protected]>
* Update i6_models/parts/conformer/mhsa.py
  Co-authored-by: Albert Zeyer <[email protected]>
* add support for export
* update doc
* remove optional and onnx distinction
* remove test without mask
* Update i6_models/parts/conformer/mhsa.py
  Co-authored-by: Albert Zeyer <[email protected]>
* move to compat
* doc
* Update i6_models/util/compat.py
  Co-authored-by: Albert Zeyer <[email protected]>
* updates
* Update i6_models/parts/conformer/mhsa.py
  Co-authored-by: Albert Zeyer <[email protected]>
* Update i6_models/parts/conformer/mhsa.py
  Co-authored-by: Albert Zeyer <[email protected]>
* updates
* doc
* Update i6_models/util/compat.py
  Co-authored-by: Albert Zeyer <[email protected]>

---------

Co-authored-by: michelwi <[email protected]>
Co-authored-by: Albert Zeyer <[email protected]>
1 parent 4ce5419 commit 7e345a1

File tree

3 files changed: +26, -11 lines


i6_models/parts/conformer/mhsa.py

Lines changed: 7 additions & 5 deletions
@@ -2,10 +2,10 @@
 
 __all__ = ["ConformerMHSAV1", "ConformerMHSAV1Config"]
 from dataclasses import dataclass
-from typing import Optional
 import torch
 
 from i6_models.config import ModelConfiguration
+from i6_models.util import compat
 
 
 @dataclass
@@ -43,17 +43,19 @@ def __init__(self, cfg: ConformerMHSAV1Config):
         )
         self.dropout = cfg.dropout
 
-    def forward(self, input_tensor: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
         """
         Apply layer norm and multi-head self attention and dropout
-        :param Optional[torch.Tensor] key_padding_mask: could be a binary or float mask of shape (B, T)
+
+        :param input_tensor: Input to the self attention of shape (B, T, F)
+        :param sequence_mask: bool mask of shape (B, T), True signals within sequence, False outside, will be inverted to match the torch.nn.MultiheadAttention module
             which will be applied/added to dot product, used to mask padded key positions out
         """
-
+        inv_sequence_mask = compat.logical_not(sequence_mask)
         output_tensor = self.layernorm(input_tensor)  # [B,T,F]
 
         output_tensor, _ = self.mhsa(
-            output_tensor, output_tensor, output_tensor, key_padding_mask=key_padding_mask, need_weights=False
+            output_tensor, output_tensor, output_tensor, key_padding_mask=inv_sequence_mask, need_weights=False
         )  # [B,T,F]
         output_tensor = torch.nn.functional.dropout(output_tensor, p=self.dropout, training=self.training)  # [B,T,F]
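With this change the module takes a boolean sequence_mask (True inside the sequence) instead of an already-inverted key_padding_mask. A minimal usage sketch, not part of the diff, assuming the positional ConformerMHSAV1Config arguments used in tests/test_conformer.py and hypothetical example lengths:

# Minimal usage sketch (not from this commit): build a bool sequence mask from
# per-sequence lengths (True = inside the sequence) and pass it to the module.
import torch
from i6_models.parts.conformer.mhsa import ConformerMHSAV1, ConformerMHSAV1Config

batch, time, feat = 4, 15, 32
lengths = torch.tensor([15, 12, 9, 5])  # hypothetical example lengths
sequence_mask = torch.arange(time)[None, :] < lengths[:, None]  # (B, T), bool

cfg = ConformerMHSAV1Config(32, 8, 0.2, 0.3)  # positional args as in tests/test_conformer.py
mhsa = ConformerMHSAV1(cfg)
output = mhsa(torch.randn(batch, time, feat), sequence_mask=sequence_mask)  # (B, T, F)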
i6_models/util/compat.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+"""
+Compatibility support for different functions. This could be for example for onnx export.
+"""
+
+import torch
+
+
+def logical_not(tensor: torch.Tensor, /) -> torch.Tensor:
+    """
+    Helper function to decide how to invert the sequence mask. For ONNX export use XOR with 1 since logical_not is not implemented.
+    Else logical_not is applied for efficiency reasons.
+
+    :param tensor: bool mask of shape (B, T) to be inverted.
+    """
+    if torch.onnx.is_in_onnx_export():
+        return torch.logical_xor(tensor, torch.ones_like(tensor))
+    else:
+        return torch.logical_not(tensor)
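The XOR branch relies on the fact that XOR-ing a boolean tensor with an all-True tensor flips every element, which is exactly what logical_not computes. A small sketch (not part of the commit) illustrating the equivalence in eager mode:

# Sketch (not from this commit): XOR with an all-True tensor inverts a bool
# mask, matching torch.logical_not element-wise.
import torch

mask = torch.tensor([[True, True, False], [True, False, False]])
via_xor = torch.logical_xor(mask, torch.ones_like(mask))
assert torch.equal(via_xor, torch.logical_not(mask))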

tests/test_conformer.py

Lines changed: 1 addition & 6 deletions
@@ -57,15 +57,10 @@ def get_output_shape(input_shape, cfg, **kwargs):
 
         return list(output.shape)
 
-    # without key padding mask
-    input_shape = [3, 10, 20]  # B,T,F
-    cfg = ConformerMHSAV1Config(20, 4, 0.1, 0.1)
-    assert get_output_shape(input_shape, cfg) == [3, 10, 20]
-
     # with key padding mask
     input_shape = [4, 15, 32]  # B,T,F
     cfg = ConformerMHSAV1Config(32, 8, 0.2, 0.3)
-    assert get_output_shape(input_shape, cfg, key_padding_mask=torch.randint(0, 2, input_shape[:2]) > 0) == [4, 15, 32]
+    assert get_output_shape(input_shape, cfg, sequence_mask=(torch.randint(0, 2, input_shape[:2]) > 0)) == [4, 15, 32]
 
 
 def test_layer_norm_nc():
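The get_output_shape helper used above is defined earlier in tests/test_conformer.py and is not shown in this diff; a hypothetical sketch of such a helper, assuming it simply runs a forward pass and returns the output shape:

# Hypothetical sketch of a helper like the test's get_output_shape (the real
# body lies outside this diff): instantiate the module, run a forward pass,
# and return the resulting shape as a list.
import torch
from i6_models.parts.conformer.mhsa import ConformerMHSAV1, ConformerMHSAV1Config

def get_output_shape(input_shape, cfg, **kwargs):
    input_tensor = torch.randn(input_shape)  # (B, T, F)
    output = ConformerMHSAV1(cfg)(input_tensor, **kwargs)
    return list(output.shape)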
