@@ -166,6 +166,37 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
         return self._cached_decode_metadata
 
 
+def _make_alibi_bias(alibi_slopes: torch.Tensor,
+                     dtype: torch.dtype,
+                     seq_lens: Optional[List[int]],
+                     make_attn_mask: bool = True) -> List[torch.Tensor]:
+    attn_biases = []
+    if seq_lens:
+        for seq_len in seq_lens:
+            bias = torch.arange(seq_len, dtype=dtype)
+            # NOTE(zhuohan): HF uses
+            #     `bias = bias[None, :].repeat(seq_len, 1)`
+            # here. We find that both biases give the same results, but
+            # the bias below more accurately follows the original ALiBi
+            # paper.
+            bias = bias[None, :] - bias[:, None]
+
+            num_heads = alibi_slopes.shape[0]
+            bias = bias[None, :].repeat(
+                (num_heads, 1, 1)).to(alibi_slopes.device)
+            bias.mul_(alibi_slopes[:, None, None])
+            if make_attn_mask:
+                inf_mask = torch.empty(
+                    (1, seq_len, seq_len),
+                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
+                        alibi_slopes.device)
+                attn_biases.append((bias + inf_mask).to(dtype))
+            else:
+                attn_biases.append(bias.to(dtype))
+
+    return attn_biases
+
+
 class ROCmFlashAttentionImpl(AttentionImpl):
     """
     If the input tensors contain prompt tokens, the layout is as follows:
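
For readers unfamiliar with ALiBi, the snippet below is a minimal standalone sketch of what the new `_make_alibi_bias` helper produces (the slope values and sequence length are hypothetical toy inputs, not taken from the patch): each head gets a `(seq_len, seq_len)` matrix of relative distances scaled by its slope, plus an optional causal `-inf` upper triangle when `make_attn_mask=True`.

```python
import torch

# Hypothetical toy inputs, only to illustrate shapes and values.
num_heads, seq_len = 4, 5
alibi_slopes = torch.tensor([0.5, 0.25, 0.125, 0.0625])

# Relative positions: bias[i, j] = j - i, following the ALiBi paper.
pos = torch.arange(seq_len, dtype=torch.float32)
bias = pos[None, :] - pos[:, None]            # (seq_len, seq_len)

# One copy per head, scaled by that head's slope.
bias = bias[None, :].repeat(num_heads, 1, 1)  # (num_heads, seq_len, seq_len)
bias = bias * alibi_slopes[:, None, None]

# make_attn_mask=True additionally folds in a causal -inf upper triangle,
# so the consumer can run attention with is_causal=False.
inf_mask = torch.full((1, seq_len, seq_len), float("-inf")).triu_(diagonal=1)
masked_bias = bias + inf_mask

print(bias.shape, masked_bias.shape)  # torch.Size([4, 5, 5]) twice
```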
@@ -324,7 +355,14 @@ def forward(
                  # triton attention
                  # When block_tables are not filled, it means q and k are the
                  # prompt, and they have the same length.
+                attn_masks = None
                  if self.use_triton_flash_attn:
+                    if self.alibi_slopes is not None:
+                        attn_masks = _make_alibi_bias(
+                            self.alibi_slopes,
+                            query.dtype,
+                            attn_metadata.seq_lens,
+                            make_attn_mask=False)  # type: ignore
                      out, _ = self.attn_func(
                          query,
                          key,
@@ -336,12 +374,20 @@ def forward(
                          prefill_meta.max_prefill_seq_len,
                          True,
                          self.scale,
+                        attn_masks[0][None]
+                        if attn_masks is not None else None,
                      )
                  elif self.use_naive_attn:
                      if self.num_kv_heads != self.num_heads:
                          # Interleave for MQA workaround.
                          key = self.repeat_kv(key, self.num_queries_per_kv)
                          value = self.repeat_kv(value, self.num_queries_per_kv)
+                    if self.alibi_slopes is not None:
+                        attn_masks = _make_alibi_bias(
+                            self.alibi_slopes,
+                            query.dtype,
+                            attn_metadata.seq_lens,
+                            make_attn_mask=True)  # type: ignore
                      query = query.movedim(0, query.dim() - 2)
                      key = key.movedim(0, key.dim() - 2)
                      value = value.movedim(0, value.dim() - 2)
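
Note how the two branches consume the helper differently: the triton path asks for the raw bias (`make_attn_mask=False`, since the kernel enforces causality itself) and passes a single batched tensor via `attn_masks[0][None]`, while the naive path asks for the full mask and later indexes one entry per sequence. A rough, self-contained shape sketch (toy sizes are hypothetical):

```python
import torch

# Hypothetical toy sizes mirroring the two branches above.
num_heads, seq_len = 4, 5
bias = torch.randn(num_heads, seq_len, seq_len)  # stand-in for one attn_masks entry

# Triton path: make_attn_mask=False, then attn_masks[0][None] adds a batch dim.
triton_arg = bias[None]
print(triton_arg.shape)   # torch.Size([1, 4, 5, 5])

# Naive path: make_attn_mask=True bakes the causal -inf triangle into the bias;
# _sdpa_attention later picks attn_masks[i] for the i-th sequence.
inf_mask = torch.full((1, seq_len, seq_len), float("-inf")).triu_(diagonal=1)
naive_mask = bias + inf_mask
print(naive_mask.shape)   # torch.Size([4, 5, 5])
```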
@@ -355,6 +401,7 @@ def forward(
                          self.num_heads,
                          self.head_size,
                          self.scale,
+                        attn_masks,
                      )
                  else:
                      out = self.attn_func(
@@ -418,13 +465,14 @@ def _sdpa_attention(
     num_heads: int,
     head_size: int,
     scale: float,
+    attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
     output = torch.empty((num_tokens, num_heads, head_size),
                          dtype=query.dtype,
                          device=query.device)
 
-    for seq_len in seq_lens:
+    for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
         with torch.backends.cuda.sdp_kernel(enable_math=True,
                                             enable_flash=False,
@@ -434,7 +482,8 @@ def _sdpa_attention(
                 key[:, start:end, :],
                 value[:, start:end, :],
                 dropout_p=0.0,
-                is_causal=True,
+                is_causal=attn_masks is None,
+                attn_mask=attn_masks[i] if attn_masks else None,
                 scale=scale).movedim(query.dim() - 2, 0)
             output[start:end, :, :] = sub_out
             start = end
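
The `_sdpa_attention` change relies on `torch.nn.functional.scaled_dot_product_attention` treating a float `attn_mask` as an additive bias, so a mask with a `-inf` upper triangle reproduces `is_causal=True` while also carrying the ALiBi term. A standalone sanity sketch (shapes and seed are arbitrary, and the ALiBi term is omitted here to make the comparison exact):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_heads, seq_len, head_size = 4, 6, 8
q = torch.randn(num_heads, seq_len, head_size)
k = torch.randn(num_heads, seq_len, head_size)
v = torch.randn(num_heads, seq_len, head_size)

# Additive mask with -inf above the diagonal, analogous to what
# _make_alibi_bias builds when make_attn_mask=True.
causal_mask = torch.full((seq_len, seq_len), float("-inf")).triu_(diagonal=1)

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask,
                                          is_causal=False)
print(torch.allclose(out_flag, out_mask, atol=1e-6))  # expected: True
```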