
Commit 58699f6

Use torch's fused SDPA for attention computation
1 parent ff22171 commit 58699f6

File tree

1 file changed: +25, -15 lines


i6_models/parts/conformer/mhsa_rel_pos.py

Lines changed: 25 additions & 15 deletions
@@ -11,8 +11,8 @@
 import torch.nn.functional as F
 
 from i6_models.config import ModelConfiguration
-from i6_models.util import compat
 from i6_models.parts.dropout import BroadcastDropout
+from i6_models.util import compat
 
 
 @dataclass
@@ -195,31 +195,41 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
         q_with_bias_u = q + self.pos_bias_u if self.with_pos_bias else q  # [B, T, #heads, F']
         q_with_bias_v = q + self.pos_bias_v if self.with_pos_bias else q
 
-        # attention matrix a and c
-        attn_ac = torch.einsum("bihf, bjhf -> bhij", q_with_bias_u, k)  # [B, #heads, T, T']
-
         # attention matrix b and d
         attn_bd = torch.einsum(
             "bihf, ijhf -> bhij", q_with_bias_v, rel_pos_embeddings
         )  # [B, #heads, T, T'] or [B, #heads, T, T+T'+1]
-
         if not self.learnable_pos_emb:
             attn_bd = self._rel_shift_bhij(attn_bd, k_len=time_dim_size)  # [B, #heads, T, T']
 
-        attn = attn_ac + attn_bd + mask  # [B, #heads, T, T']
-        attn_scaled = attn * (math.sqrt(1.0 / float(self.embed_dim_per_head)))  # [B, #heads, T, T']
+        # We use attn_mask to add the BD matrix to the attention scores.
+        #
+        # Inside torch's SDPA the mask is added after regular scaling, so to get correct
+        # results, we need to apply the scaling here before passing to SDPA.
+        #
+        # See for reference:
+        # https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+        attn_bd_mask = attn_bd + mask
+        scale = math.sqrt(1.0 / float(self.embed_dim_per_head))
+        attn_bd_mask_scaled = attn_bd_mask * scale
 
-        # softmax and dropout
-        attn_output_weights = self.att_weights_dropout(F.softmax(attn_scaled, dim=-1))  # [B, #heads, T, T']
-
-        # sequence of weighted sums over value sequence
         v = value_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head)  # [B, T, H, F']
-        attn_output = torch.einsum("bhij, bjhf -> bihf", attn_output_weights, v).reshape(
-            batch_dim_size, -1, self.embed_dim
-        )
 
-        output_tensor = self.out_proj(attn_output)
+        # Use torch's SDPA for efficiency.
+        #
+        # The attention matrices a and c are computed inside torch's SDPA.
+        attn_output = F.scaled_dot_product_attention(
+            q_with_bias_u.transpose(-3, -2),  # [B, #heads, T, F']
+            k.transpose(-3, -2),  # [B, #heads, T', F']
+            v.transpose(-3, -2),  # [B, #heads, T', F']
+            attn_mask=attn_bd_mask_scaled,  # [B, #heads, T, T']
+            dropout_p=self.att_weights_dropout.p if self.training else 0.0,
+            scale=scale,
+        )  # [B, #heads, T, F']
+        attn_output = attn_output.transpose(-3, -2).flatten(-2)  # [B, T, F]
+        assert attn_output.shape[-1] == self.embed_dim
 
+        output_tensor = self.out_proj(attn_output)
         output_tensor = self.dropout(output_tensor)
 
         return output_tensor  # [B,T,F]
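
Note (not part of the commit): the change relies on the fact that torch's F.scaled_dot_product_attention adds attn_mask after scaling Q·Kᵀ, which is why the BD-plus-mask bias is pre-scaled before being passed in. A minimal standalone sketch with made-up tensor sizes that checks this equivalence against the explicit softmax computation the commit removes:

# Hedged sketch, independent of i6_models: verify that passing a pre-scaled additive
# bias as attn_mask to SDPA matches softmax((Q K^T + bias) * scale) @ V.
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, F_head = 2, 4, 6, 8  # illustrative sizes: batch, heads, time, feature per head

q = torch.randn(B, H, T, F_head)
k = torch.randn(B, H, T, F_head)
v = torch.randn(B, H, T, F_head)
bias = torch.randn(B, H, T, T)  # stands in for attn_bd + mask

scale = math.sqrt(1.0 / F_head)

# Explicit path (what the commit removes), without dropout:
scores = torch.einsum("bhif, bhjf -> bhij", q, k) + bias
reference = torch.einsum("bhij, bhjf -> bhif", torch.softmax(scores * scale, dim=-1), v)

# Fused path: SDPA computes softmax(Q K^T * scale + attn_mask) @ V,
# so the bias has to carry the scale factor already.
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=bias * scale, scale=scale)

print(torch.allclose(reference, fused, atol=1e-5))  # expected: True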
