
Commit 6bd8ebf

mxz297, mgoin, and gemini-code-assist[bot] authored
[Kernel][AMD] Avoid D2H copy and cumsum kernel (#22683)
Signed-off-by: Xiaozhu <[email protected]>
Signed-off-by: Michael Goin <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent dab4f9f commit 6bd8ebf

File tree

1 file changed: +20 −12 lines changed

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 20 additions & 12 deletions
@@ -214,12 +214,14 @@ class AiterFlashAttentionMetadata:
     # |-- query_len ---|
 
     num_actual_tokens: int  # Number of tokens excluding padding.
+    num_actual_kv_tokens: int
     max_query_len: int
     query_start_loc: torch.Tensor
     max_seq_len: int
     seq_lens: torch.Tensor
     slot_mapping: torch.Tensor
     block_table: torch.Tensor
+    cu_seq_lens: Optional[torch.Tensor]
 
     # For cascade attention.
     use_cascade: bool
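For reference, `cu_seq_lens` follows the usual varlen-attention convention: entry `i` holds the total number of KV tokens belonging to sequences `0..i-1`, so consecutive entries delimit each sequence's slice of the packed KV data, and the final entry is the batch-wide KV token count stored in `num_actual_kv_tokens`. A minimal, self-contained sketch (the lengths below are made up for illustration, not taken from the commit):

```python
import torch

# Hypothetical per-sequence KV lengths for a batch of three requests.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# Prefix sums with a leading zero: [0, 5, 8, 15].
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=torch.int32, out=cu_seq_lens[1:])

# Sequence i owns KV tokens [cu_seq_lens[i], cu_seq_lens[i + 1]).
num_actual_kv_tokens = int(cu_seq_lens[-1].item())  # 15
print(cu_seq_lens.tolist(), num_actual_kv_tokens)
```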
@@ -272,6 +274,20 @@ def build(self,
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
+        if max_query_len > 1:
+            # We pre-compute cumulative seq len needed for prefill attention
+            # here to avoid recomputing it for every layer
+            cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
+                                      dtype=torch.int32,
+                                      device=seq_lens.device)
+            torch.cumsum(seq_lens,
+                         dim=0,
+                         dtype=cu_seq_lens.dtype,
+                         out=cu_seq_lens[1:])
+            num_actual_kv_tokens = int(cu_seq_lens[-1].item())
+        else:
+            cu_seq_lens = None
+            num_actual_kv_tokens = 0
 
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
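A rough standalone sketch of the pattern this hunk adds (the helper name `build_prefill_cu_seq_lens` is illustrative, not part of vLLM): the prefix sums are built with one `torch.cumsum` launch per batch, and the single `.item()` call is the only host sync, rather than repeating that work in every layer's forward pass as the deleted code further down did.

```python
from typing import Optional

import torch


def build_prefill_cu_seq_lens(
        seq_lens: torch.Tensor,
        max_query_len: int) -> tuple[Optional[torch.Tensor], int]:
    """Illustrative helper mirroring the precompute-once pattern above."""
    if max_query_len > 1:
        # One cumsum kernel launch per batch, reused by every attention layer.
        cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
                                  dtype=torch.int32,
                                  device=seq_lens.device)
        torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype,
                     out=cu_seq_lens[1:])
        # Single device-to-host copy per batch; layers reuse the cached int.
        num_actual_kv_tokens = int(cu_seq_lens[-1].item())
        return cu_seq_lens, num_actual_kv_tokens
    # Decode-only batches (max_query_len == 1) skip the prefill path.
    return None, 0
```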
@@ -281,12 +297,14 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
 
         attn_metadata = AiterFlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
+            num_actual_kv_tokens=num_actual_kv_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
+            cu_seq_lens=cu_seq_lens,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
             total_tokens=self.total_tokens,
@@ -475,16 +493,6 @@ def forward(
         block_table = attn_metadata.block_table
 
         if max_seqlen_q > 1:
-
-            cu_seq_lens = torch.zeros(seqused_k.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=query.device)
-
-            torch.cumsum(seqused_k,
-                         dim=0,
-                         dtype=cu_seq_lens.dtype,
-                         out=cu_seq_lens[1:])
-
            torch.ops.vllm.flash_attn_varlen_func(
                 query[:num_actual_tokens],
                 key_cache,
@@ -497,10 +505,10 @@ def forward(
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
-                cu_seqlens_k=cu_seq_lens,
+                cu_seqlens_k=attn_metadata.cu_seq_lens,
                 k_scale=layer._k_scale,
                 v_scale=layer._v_scale,
-                total_tokens=attn_metadata.total_tokens,
+                total_tokens=attn_metadata.num_actual_kv_tokens,
             )
 
         _, num_heads, head_size = query.shape
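Net effect on the per-layer hot path, sketched with stand-in names (`FakeMetadata` and `run_layer_attention` are hypothetical, not vLLM symbols): each layer's forward now only reads the prebuilt fields, so the per-layer `torch.zeros` + `torch.cumsum` removed above is gone.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FakeMetadata:
    # Stand-in for the AiterFlashAttentionMetadata fields read per layer.
    num_actual_tokens: int
    num_actual_kv_tokens: int
    cu_seq_lens: Optional[torch.Tensor]


def run_layer_attention(query: torch.Tensor, md: FakeMetadata) -> None:
    # After this commit each layer just reads precomputed metadata; the real
    # code passes these values to torch.ops.vllm.flash_attn_varlen_func.
    assert md.cu_seq_lens is not None
    assert int(md.cu_seq_lens[-1]) == md.num_actual_kv_tokens
    _ = query[:md.num_actual_tokens]  # kernel call would consume this view
```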
