Skip to content

Commit 324fc4a

Browse files
committed
revertme: debug shapes
1 parent 9ab3e9b commit 324fc4a

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

i6_models/parts/conformer/mhsa_rel_pos.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,10 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
216216
v = value_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head) # [B, T, H, F']
217217

218218
# attention matrix a and c is computed inside torch's sdpa
219+
print("q_with_bias_u", q_with_bias_u.transpose(-3, -2))
220+
print("k", k.transpose(-3, -2))
221+
print("v", v.transpose(-3, -2))
222+
print("attn_bd_mask_scaled", attn_bd_mask_scaled.transpose(-3, -2))
219223
attn_output = F.scaled_dot_product_attention(
220224
q_with_bias_u.transpose(-3, -2), # [B, #heads, T, F']
221225
k.transpose(-3, -2), # [B, #heads, T', F']

0 commit comments

Comments
 (0)