@@ -52,7 +52,7 @@ def __init__(self, cfg: ConformerBlockV1Config):
5252 def forward (self , x : torch .Tensor , / , sequence_mask : torch .Tensor ) -> torch .Tensor :
5353 """
5454 :param x: input tensor of shape [B, T, F]
55- :param sequence_mask: mask tensor where 0 defines positions within the sequence and 1 outside, shape: [B, T]
55+ :param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T]
5656 :return: torch.Tensor of shape [B, T, F]
5757 """
5858 x = 0.5 * self .ff1 (x ) + x # [B, T, F]
@@ -98,7 +98,7 @@ def __init__(self, cfg: ConformerEncoderV1Config):
9898 def forward (self , data_tensor : torch .Tensor , sequence_mask : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
9999 """
100100 :param data_tensor: input tensor of shape [B, T', F]
101- :param sequence_mask: mask tensor where 0 defines positions within the sequence and 1 outside, shape: [B, T']
101+ :param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T']
102102 :return: (output, out_seq_mask)
103103 where output is torch.Tensor of shape [B, T, F'],
104104 out_seq_mask is a torch.Tensor of shape [B, T]
0 commit comments