11 | 11 | import torch.nn.functional as F
12 | 12 |
13 | 13 | from i6_models.config import ModelConfiguration
14 | | -from i6_models.util import compat |
15 | 14 | from i6_models.parts.dropout import BroadcastDropout |
| 15 | +from i6_models.util import compat |
16 | 16 |
17 | 17 |
18 | 18 | @dataclass |
@@ -195,31 +195,41 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
195 | 195 | q_with_bias_u = q + self.pos_bias_u if self.with_pos_bias else q # [B, T, #heads, F'] |
196 | 196 | q_with_bias_v = q + self.pos_bias_v if self.with_pos_bias else q |
197 | 197 |
198 | | - # attention matrix a and c |
199 | | - attn_ac = torch.einsum("bihf, bjhf -> bhij", q_with_bias_u, k) # [B, #heads, T, T'] |
200 | | - |
201 | 198 | # attention matrix b and d |
202 | 199 | attn_bd = torch.einsum( |
203 | 200 | "bihf, ijhf -> bhij", q_with_bias_v, rel_pos_embeddings |
204 | 201 | ) # [B, #heads, T, T'] or [B, #heads, T, T+T'+1] |
205 | | - |
206 | 202 | if not self.learnable_pos_emb: |
207 | 203 | attn_bd = self._rel_shift_bhij(attn_bd, k_len=time_dim_size) # [B, #heads, T, T'] |
208 | 204 |
209 | | - attn = attn_ac + attn_bd + mask # [B, #heads, T, T'] |
210 | | - attn_scaled = attn * (math.sqrt(1.0 / float(self.embed_dim_per_head))) # [B, #heads, T, T'] |
| 205 | + # We use attn_mask to add BD matrix to attention scores. |
| 206 | + # |
| 207 | + # Inside torch's SDPA the mask is added after regular scaling, so to get correct |
| 208 | + # results, we need to apply the scaling here before passing to SDPA. |
| 209 | + # |
| 210 | + # See for reference: |
| 211 | + # https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html |
| 212 | + attn_bd_mask = attn_bd + mask |
| 213 | + scale = math.sqrt(1.0 / float(self.embed_dim_per_head)) |
| 214 | + attn_bd_mask_scaled = attn_bd_mask * scale |
211 | 215 |
212 | | - # softmax and dropout |
213 | | - attn_output_weights = self.att_weights_dropout(F.softmax(attn_scaled, dim=-1)) # [B, #heads, T, T'] |
214 | | - |
215 | | - # sequence of weighted sums over value sequence |
216 | 216 | v = value_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head) # [B, T, H, F'] |
217 | | - attn_output = torch.einsum("bhij, bjhf -> bihf", attn_output_weights, v).reshape( |
218 | | - batch_dim_size, -1, self.embed_dim |
219 | | - ) |
220 | 217 |
221 | | - output_tensor = self.out_proj(attn_output) |
| 218 | + # Use torch's SDPA for efficiency. |
| 219 | + # |
| 220 | + # The attention matrices a and c are computed inside torch's SDPA.
| 221 | + attn_output = F.scaled_dot_product_attention( |
| 222 | + q_with_bias_u.transpose(-3, -2), # [B, #heads, T, F'] |
| 223 | + k.transpose(-3, -2), # [B, #heads, T', F'] |
| 224 | + v.transpose(-3, -2), # [B, #heads, T', F']
| 225 | + attn_mask=attn_bd_mask_scaled, # [B, #heads, T, T'] |
| 226 | + dropout_p=self.att_weights_dropout.p if self.training else 0.0, |
| 227 | + scale=scale, |
| 228 | + ) # [B, #heads, T, F'] |
| 229 | + attn_output = attn_output.transpose(-3, -2).flatten(-2) # [B, T, F]
| 230 | + assert attn_output.shape[-1] == self.embed_dim |
222 | 231 |
| 232 | + output_tensor = self.out_proj(attn_output) |
223 | 233 | output_tensor = self.dropout(output_tensor) |
224 | 234 |
225 | 235 | return output_tensor # [B,T,F] |
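
The point of the added comments is that torch's SDPA adds `attn_mask` only after scaling `q @ k^T`, which is why the b/d term and the sequence mask are pre-multiplied by `scale` before being passed in. Below is a minimal standalone sketch of that equivalence, not part of the diff: dropout is disabled, the einsum strings and the transpose/flatten steps mirror the diff, and the tensor sizes, random inputs, and the pre-shifted `attn_bd` are made up purely for the check.

```python
import math
import torch
import torch.nn.functional as F

B, T, H, Fh = 2, 5, 4, 8  # assumed toy sizes: batch, time, heads, per-head feature dim
q_u = torch.randn(B, T, H, Fh)     # q + pos_bias_u, [B, T, #heads, F']
k = torch.randn(B, T, H, Fh)       # [B, T', #heads, F']
v = torch.randn(B, T, H, Fh)       # [B, T', #heads, F']
attn_bd = torch.randn(B, H, T, T)  # b/d term after the relative shift, [B, #heads, T, T']
mask = torch.zeros(B, 1, 1, T)     # additive mask: 0 = keep, -inf = masked
scale = math.sqrt(1.0 / float(Fh))

# Old path: build the full score matrix, scale, softmax, weighted sum over v.
attn_ac = torch.einsum("bihf, bjhf -> bhij", q_u, k)  # [B, #heads, T, T']
weights = F.softmax((attn_ac + attn_bd + mask) * scale, dim=-1)
out_manual = torch.einsum("bhij, bjhf -> bihf", weights, v).reshape(B, -1, H * Fh)

# New path: SDPA scales q @ k^T by `scale` and only then adds attn_mask,
# so the additive (attn_bd + mask) term is pre-scaled to land in the same place.
out_sdpa = F.scaled_dot_product_attention(
    q_u.transpose(-3, -2),
    k.transpose(-3, -2),
    v.transpose(-3, -2),
    attn_mask=(attn_bd + mask) * scale,
    dropout_p=0.0,
    scale=scale,
).transpose(-3, -2).flatten(-2)  # [B, T, F]

torch.testing.assert_close(out_manual, out_sdpa)
```

With `dropout_p > 0` the two paths only agree in expectation, since SDPA draws its own dropout mask on the attention weights instead of going through `att_weights_dropout`.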