
Commit 56b325e

gshtras and hongxiayang authored
[ROCm][AMD][Model]Adding alibi slopes support in ROCm triton flash attention and naive flash attention (#6043)
Co-authored-by: Hongxia Yang <[email protected]>
1 parent 3dd5070 commit 56b325e

File tree

1 file changed: +51 -2 lines changed


vllm/attention/backends/rocm_flash_attn.py

Lines changed: 51 additions & 2 deletions
@@ -166,6 +166,37 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
         return self._cached_decode_metadata
 
 
+def _make_alibi_bias(alibi_slopes: torch.Tensor,
+                     dtype: torch.dtype,
+                     seq_lens: Optional[List[int]],
+                     make_attn_mask: bool = True) -> List[torch.Tensor]:
+    attn_biases = []
+    if seq_lens:
+        for seq_len in seq_lens:
+            bias = torch.arange(seq_len, dtype=dtype)
+            # NOTE(zhuohan): HF uses
+            #     `bias = bias[None, :].repeat(seq_len, 1)`
+            # here. We find that both biases give the same results, but
+            # the bias below more accurately follows the original ALiBi
+            # paper.
+            bias = bias[None, :] - bias[:, None]
+
+            num_heads = alibi_slopes.shape[0]
+            bias = bias[None, :].repeat(
+                (num_heads, 1, 1)).to(alibi_slopes.device)
+            bias.mul_(alibi_slopes[:, None, None])
+            if make_attn_mask:
+                inf_mask = torch.empty(
+                    (1, seq_len, seq_len),
+                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
+                        alibi_slopes.device)
+                attn_biases.append((bias + inf_mask).to(dtype))
+            else:
+                attn_biases.append(bias.to(dtype))
+
+    return attn_biases
+
+
 class ROCmFlashAttentionImpl(AttentionImpl):
     """
     If the input tensors contain prompt tokens, the layout is as follows:

@@ -324,7 +355,14 @@ def forward(
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
+                attn_masks = None
                 if self.use_triton_flash_attn:
+                    if self.alibi_slopes is not None:
+                        attn_masks = _make_alibi_bias(
+                            self.alibi_slopes,
+                            query.dtype,
+                            attn_metadata.seq_lens,
+                            make_attn_mask=False)  # type: ignore
                     out, _ = self.attn_func(
                         query,
                         key,

@@ -336,12 +374,20 @@ def forward(
                         prefill_meta.max_prefill_seq_len,
                         True,
                         self.scale,
+                        attn_masks[0][None]
+                        if attn_masks is not None else None,
                     )
                 elif self.use_naive_attn:
                     if self.num_kv_heads != self.num_heads:
                         # Interleave for MQA workaround.
                         key = self.repeat_kv(key, self.num_queries_per_kv)
                         value = self.repeat_kv(value, self.num_queries_per_kv)
+                    if self.alibi_slopes is not None:
+                        attn_masks = _make_alibi_bias(
+                            self.alibi_slopes,
+                            query.dtype,
+                            attn_metadata.seq_lens,
+                            make_attn_mask=True)  # type: ignore
                     query = query.movedim(0, query.dim() - 2)
                     key = key.movedim(0, key.dim() - 2)
                     value = value.movedim(0, value.dim() - 2)

@@ -355,6 +401,7 @@ def forward(
                         self.num_heads,
                         self.head_size,
                         self.scale,
+                        attn_masks,
                     )
                 else:
                     out = self.attn_func(

@@ -418,13 +465,14 @@ def _sdpa_attention(
     num_heads: int,
     head_size: int,
     scale: float,
+    attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
     output = torch.empty((num_tokens, num_heads, head_size),
                          dtype=query.dtype,
                          device=query.device)
 
-    for seq_len in seq_lens:
+    for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
         with torch.backends.cuda.sdp_kernel(enable_math=True,
                                             enable_flash=False,

@@ -434,7 +482,8 @@ def _sdpa_attention(
                 key[:, start:end, :],
                 value[:, start:end, :],
                 dropout_p=0.0,
-                is_causal=True,
+                is_causal=attn_masks is None,
+                attn_mask=attn_masks[i] if attn_masks else None,
                 scale=scale).movedim(query.dim() - 2, 0)
             output[start:end, :, :] = sub_out
             start = end
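
For reference, here is a minimal standalone sketch (not part of the commit) of the bias that the _make_alibi_bias hunk above builds: per head, entry (i, j) of the bias is slope * (j - i), and when make_attn_mask=True the causal structure is folded in as -inf above the diagonal. The 2-head, length-4 values below are assumptions chosen for illustration.

# Sketch of the ALiBi bias produced by _make_alibi_bias, with assumed toy
# sizes: 2 heads, one sequence of length 4.
import torch

alibi_slopes = torch.tensor([0.5, 0.25])   # assumed per-head slopes
seq_len = 4
dtype = torch.float32

# Relative positions: entry (i, j) = j - i, following the ALiBi paper.
pos = torch.arange(seq_len, dtype=dtype)
bias = pos[None, :] - pos[:, None]                  # (seq_len, seq_len)

# One copy per head, scaled by that head's slope.
num_heads = alibi_slopes.shape[0]
bias = bias[None, :, :].repeat(num_heads, 1, 1)     # (num_heads, seq_len, seq_len)
bias = bias * alibi_slopes[:, None, None]

# make_attn_mask=True (naive/SDPA path): add -inf above the diagonal so the
# mask is causal on its own. make_attn_mask=False (triton path) stops above,
# since the triton kernel applies causality itself.
inf_mask = torch.full((1, seq_len, seq_len), float("-inf")).triu(diagonal=1)
attn_bias = (bias + inf_mask).to(dtype)             # (num_heads, seq_len, seq_len)

print(attn_bias[0])
# head 0 (slope 0.5): distance scaled by the slope, -inf above the diagonal:
# [[ 0.0, -inf, -inf, -inf],
#  [-0.5,  0.0, -inf, -inf],
#  [-1.0, -0.5,  0.0, -inf],
#  [-1.5, -1.0, -0.5,  0.0]]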
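Similarly, a hedged sketch of how the patched _sdpa_attention loop consumes those per-sequence masks: torch.nn.functional.scaled_dot_product_attention runs with is_causal=True only when no mask is supplied; otherwise the (already causal) ALiBi mask goes in via attn_mask. The toy shapes and slopes are assumptions, and the sdp_kernel backend selection from the real function is omitted for brevity.

# Sketch (assumed toy shapes) of the per-sequence SDPA loop after this commit.
from typing import List, Optional

import torch
import torch.nn.functional as F

num_heads, head_size = 2, 8
seq_lens = [4, 3]                     # two prompts packed along the token axis
num_tokens = sum(seq_lens)
slopes = torch.tensor([0.5, 0.25])    # assumed per-head ALiBi slopes


def toy_alibi_mask(seq_len: int) -> torch.Tensor:
    # Same recipe as the _make_alibi_bias hunk, reduced to one sequence.
    pos = torch.arange(seq_len, dtype=torch.float32)
    bias = (pos[None, :] - pos[:, None]) * slopes[:, None, None]
    causal = torch.full((1, seq_len, seq_len), float("-inf")).triu(diagonal=1)
    return bias + causal


attn_masks: Optional[List[torch.Tensor]] = [toy_alibi_mask(s) for s in seq_lens]

# (num_heads, num_tokens, head_size), i.e. after forward()'s movedim() calls.
query = torch.randn(num_heads, num_tokens, head_size)
key = torch.randn(num_heads, num_tokens, head_size)
value = torch.randn(num_heads, num_tokens, head_size)

output = torch.empty(num_tokens, num_heads, head_size)
start = 0
for i, seq_len in enumerate(seq_lens):
    end = start + seq_len
    sub_out = F.scaled_dot_product_attention(
        query[:, start:end, :],
        key[:, start:end, :],
        value[:, start:end, :],
        dropout_p=0.0,
        is_causal=attn_masks is None,                    # plain causal only without a mask
        attn_mask=attn_masks[i] if attn_masks else None,
        scale=head_size**-0.5).movedim(query.dim() - 2, 0)
    output[start:end, :, :] = sub_out
    start = end

print(output.shape)   # torch.Size([7, 2, 8])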

0 commit comments
