Commit 9f64e93

[BugFix][AMD] Compatible patch for latest AITER(05/07/2025) (#17864)
Signed-off-by: Qiang Li <[email protected]>
1 parent ec61ea2 commit 9f64e93

File tree (4 files changed, +54 −23 lines):

vllm/attention/backends/mla/common.py
vllm/attention/backends/rocm_aiter_mla.py
vllm/attention/ops/rocm_aiter_mla.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

vllm/attention/backends/mla/common.py

Lines changed: 6 additions & 6 deletions
@@ -1213,9 +1213,9 @@ def _compute_prefill_context(
         attn_output, attn_softmax_lse = \
             self._flash_attn_varlen_diff_headdims(
-            q=q,
-            k=k,
-            v=v,
+            q,
+            k,
+            v,
             cu_seqlens_q=prefill_metadata.query_start_loc,
             cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
             max_seqlen_q=prefill_metadata.max_query_len,
@@ -1267,9 +1267,9 @@ def _forward_prefill(
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

         output = self._flash_attn_varlen_diff_headdims(
-            q=q,
-            k=k,
-            v=v,
+            q,
+            k,
+            v,
             cu_seqlens_q=prefill_metadata.query_start_loc,
             cu_seqlens_k=prefill_metadata.query_start_loc,
             max_seqlen_q=prefill_metadata.max_prefill_seq_len,
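The keyword-to-positional switch above presumably tracks a parameter rename or reorder in the updated AITER `flash_attn_varlen_func`. A minimal sketch (hypothetical signatures, not the AITER API) of why positional forwarding survives an upstream rename while keyword forwarding does not:

```python
# Hypothetical signatures illustrating the failure mode, not the AITER API.
def old_flash_attn(q, k, v, softmax_scale=1.0):
    return q + k + v

def new_flash_attn(query, key, value, scale=1.0):  # params renamed upstream
    return query + key + value

for fn in (old_flash_attn, new_flash_attn):
    fn(1, 2, 3)                 # positional call works against both versions

old_flash_attn(q=1, k=2, v=3)   # fine before the rename
try:
    new_flash_attn(q=1, k=2, v=3)   # keyword call breaks after the rename
except TypeError as e:
    print(e)                    # ... got an unexpected keyword argument 'q'
```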

vllm/attention/backends/rocm_aiter_mla.py

Lines changed: 37 additions & 12 deletions
@@ -53,7 +53,7 @@ def get_state_cls() -> Type["AiterMLAState"]:
 @dataclass
 class AiterMLAMetadata(MLACommonMetadata):
-    # The following 4 tensors are for current version of AITER MLA
+    # The following 5 tensors are for current version of AITER MLA
     block_table_bound: Optional[torch.Tensor] = None
     # The indptr of the paged kv cache, shape: [batch_size + 1]
     paged_kv_indptr: Optional[torch.Tensor] = None
@@ -63,6 +63,10 @@ class AiterMLAMetadata(MLACommonMetadata):
     # the paged kv cache, shape: [batch_size]
     paged_kv_last_page_lens: Optional[torch.Tensor] = None

+    # This is just to make new AITER MLA API work
+    # -- MTP support is not added yet.
+    qo_indptr: Optional[torch.Tensor] = None
+
     @property
     def prefill_metadata(self):
         prefill_metadata = super().prefill_metadata
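The new `qo_indptr` field follows the same `indptr` convention as `paged_kv_indptr`: element `i + 1` minus element `i` gives sequence `i`'s share of a flattened dimension, here the number of query tokens. A small sketch of that convention, mirroring how the builder below fills it for a pure decode batch:

```python
import torch

# indptr convention: qo_indptr[i + 1] - qo_indptr[i] = query tokens of seq i.
# In the decode path each sequence contributes exactly one query token.
query_lens = torch.tensor([1, 1, 1], dtype=torch.int)  # 3 decode sequences
qo_indptr = torch.zeros(len(query_lens) + 1, dtype=torch.int)
qo_indptr[1:] = torch.cumsum(query_lens, dim=0)
print(qo_indptr.tolist())  # [0, 1, 2, 3]
```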
@@ -74,6 +78,7 @@ def prefill_metadata(self):
             prefill_metadata\
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             prefill_metadata.block_table_bound = self.block_table_bound
+            prefill_metadata.qo_indptr = self.qo_indptr

             # update the cache
             self._cached_prefill_metadata = self.__class__(
@@ -93,6 +98,7 @@ def decode_metadata(self):
             decode_metadata\
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             decode_metadata.block_table_bound = self.block_table_bound
+            decode_metadata.qo_indptr = self.qo_indptr

             # update the cache
             self._cached_decode_metadata = self.__class__(
@@ -136,6 +142,7 @@ def prepare(self):
         self.paged_kv_indptr: list[int] = [0]
         self.paged_kv_last_page_lens: list[int] = []
         self.total_blocks = 0
+        self.qo_indptr: list[int] = [0]

     def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
                        prefix_cache_hit: bool):
@@ -208,6 +215,7 @@ def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
         self.paged_kv_indices.extend(block_table[:block_table_bound])
         self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                     block_table_bound)
+        self.qo_indptr.append(self.qo_indptr[-1] + 1)

         last_page_len = seq_len % self.block_size
         if last_page_len == 0:
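For context, the surrounding paged-KV bookkeeping works in whole blocks: a sequence of `seq_len` tokens occupies `ceil(seq_len / block_size)` cache blocks, and the `last_page_len == 0` branch presumably resets the remainder to a full `block_size`. A worked example under those assumptions:

```python
import math

block_size, seq_len = 16, 35
num_blocks = math.ceil(seq_len / block_size)  # 3 entries go into paged_kv_indices
last_page_len = seq_len % block_size          # 3 tokens used in the final block
if last_page_len == 0:                        # evenly divided: last block is full
    last_page_len = block_size
print(num_blocks, last_page_len)              # 3 3
```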
@@ -226,6 +234,8 @@ def build(self, seq_lens: list[int], query_lens: list[int],
             self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                         cuda_graph_pad_size)
             self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
+            last_qo_indptr = self.qo_indptr[-1]
+            self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)

         # For current version of AITER MLA
         if len(self.paged_kv_indptr) > 0:
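Repeating the last `indptr` value for the CUDA-graph pad slots gives each dummy sequence an empty token range, so the captured graph runs on fixed-size tensors without reading garbage. A sketch of the padding above:

```python
qo_indptr = [0, 1, 2, 3]         # three real decode sequences
cuda_graph_pad_size = 2
last_qo_indptr = qo_indptr[-1]
qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)
print(qo_indptr)                 # [0, 1, 2, 3, 3, 3]; pad slots span zero tokens
```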
@@ -245,16 +255,22 @@ def build(self, seq_lens: list[int], query_lens: list[int],
                 1,
                 device=device,
                 dtype=torch.int)
+
+            qo_indptr = torch.tensor(self.qo_indptr,
+                                     device=device,
+                                     dtype=torch.int)
         else:
             paged_kv_indices_tensor = None
             paged_kv_indptr_tensor = None
             paged_kv_last_page_lens_tensor = None
             block_table_bound_tensor = None
+            qo_indptr = None

         metadata.paged_kv_indptr = paged_kv_indptr_tensor
         metadata.paged_kv_indices = paged_kv_indices_tensor
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
         metadata.block_table_bound = block_table_bound_tensor
+        metadata.qo_indptr = qo_indptr

         return metadata
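Putting the pieces together, a sketch of the tensors `build()` would hand to the kernel for a two-sequence decode batch with `block_size = 16` and `seq_lens = [20, 16]` (values derived from the bookkeeping above, not from a real run):

```python
import torch

device = "cpu"  # "cuda" on a real ROCm run
# seq 0: blocks [0, 1], 4 tokens in its last block; seq 1: block [2], full last block
paged_kv_indices = torch.tensor([0, 1, 2], device=device, dtype=torch.int)
paged_kv_indptr = torch.tensor([0, 2, 3], device=device, dtype=torch.int)
paged_kv_last_page_lens = torch.tensor([4, 16], device=device, dtype=torch.int)
qo_indptr = torch.tensor([0, 1, 2], device=device, dtype=torch.int)  # one query token each
```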

@@ -263,21 +279,25 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
     @contextmanager
     def graph_capture(self, max_batch_size: int):
-        kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
-            max_batch_size=max_batch_size,
-            block_size=self.runner.block_size,
-            max_block_per_batch=self.runner.get_max_block_per_batch(),
-            device=self.runner.device)
+        kv_indices, kv_indptr, last_page_lens, qo_indptr = \
+            get_aiter_mla_metadata(
+                max_batch_size=max_batch_size,
+                block_size=self.runner.block_size,
+                max_block_per_batch=\
+                    self.runner.get_max_block_per_batch(),
+                device=self.runner.device)
         self._paged_kv_indices_tensor = kv_indices
         self._paged_kv_indptr_tensor = kv_indptr
         self._paged_kv_last_page_lens_tensor = last_page_lens
+        self._qo_indptr_tensor = qo_indptr

         with super().graph_capture(max_batch_size):
             yield

         del self._paged_kv_indices_tensor
         del self._paged_kv_indptr_tensor
         del self._paged_kv_last_page_lens_tensor
+        del self._qo_indptr_tensor

     def graph_capture_get_metadata_for_batch(
         self,
@@ -291,10 +311,12 @@ def graph_capture_get_metadata_for_batch(
         paged_kv_indices = self._paged_kv_indices_tensor
         paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
                                                                        batch_size]
+        qo_indptr = self._qo_indptr_tensor[:batch_size + 1]

         metadata.paged_kv_indptr = paged_kv_indptr
         metadata.paged_kv_indices = paged_kv_indices
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
+        metadata.qo_indptr = qo_indptr

         return metadata
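CUDA graph replay requires the captured tensors to keep their memory addresses, which is why `graph_capture` allocates max-batch-size buffers once and this method hands out per-batch views; an `indptr` for `batch_size` sequences needs `batch_size + 1` entries, hence the slice bound. A sketch of the view pattern:

```python
import torch

max_batch_size = 8
qo_indptr_buf = torch.zeros(max_batch_size + 1, dtype=torch.int)

batch_size = 3
qo_indptr = qo_indptr_buf[:batch_size + 1]      # a view, not a copy
assert qo_indptr.data_ptr() == qo_indptr_buf.data_ptr()
qo_indptr.copy_(torch.arange(batch_size + 1, dtype=torch.int))  # refill in place
```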

@@ -311,6 +333,7 @@ def get_graph_input_buffers(self,
             input_buffers[
                 "paged_kv_last_page_lens"] = attn_metadata.\
                     decode_metadata.paged_kv_last_page_lens
+            input_buffers['qo_indptr'] = attn_metadata.qo_indptr

         return input_buffers

@@ -330,6 +353,8 @@ def prepare_graph_input_buffers(self,
             input_buffers["paged_kv_last_page_lens"].copy_(
                 attn_metadata.decode_metadata.paged_kv_last_page_lens,
                 non_blocking=True)
+            input_buffers["qo_indptr"].copy_(
+                attn_metadata.decode_metadata.qo_indptr, non_blocking=True)


 class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
@@ -370,11 +395,9 @@ def _flash_attn_varlen_diff_headdims(
             softmax_scale: float, return_softmax_lse: bool,
             **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
         output = self.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v,
-            softmax_scale=softmax_scale,
-            return_lse=return_softmax_lse,
+            q,
+            k,
+            v,
             **kwargs,
         )

@@ -394,7 +417,7 @@ def _forward_decode(
         B = q_nope.shape[0]

         q = torch.cat([q_nope, q_pe], dim=-1)
-        o = torch.zeros(B,
+        o = torch.empty(B,
                         self.num_heads,
                         self.kv_lora_rank,
                         dtype=q.dtype,
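`torch.zeros` allocates and then fills; `torch.empty` only allocates. Switching to `empty` is safe here on the assumption that the decode kernel writes every element of `o` rather than accumulating into it, and it drops a memset from the decode hot path:

```python
import torch

B, H, D = 4, 16, 512
o = torch.empty(B, H, D)  # uninitialized; valid only if the kernel overwrites all of it
z = torch.zeros(B, H, D)  # allocation plus an extra fill kernel
```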
@@ -403,6 +426,8 @@
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

         aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
+                             attn_metadata.qo_indptr,
+                             attn_metadata.max_query_len,
                              attn_metadata.paged_kv_indptr,
                              attn_metadata.paged_kv_indices,
                              attn_metadata.paged_kv_last_page_lens)

vllm/attention/ops/rocm_aiter_mla.py

Lines changed: 6 additions & 1 deletion
@@ -20,14 +20,17 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
     paged_kv_last_page_lens = torch.full((max_batch_size, ),
                                          block_size,
                                          dtype=torch.int32)
-    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens
+    qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
+    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr


 def aiter_mla_decode_fwd(
     q: torch.Tensor,
     kv_buffer: torch.Tensor,
     o: torch.Tensor,
     sm_scale: float,
+    qo_indptr: torch.Tensor,
+    max_seqlen_qo: int,
     kv_indptr: Optional[torch.Tensor] = None,
     kv_indices: Optional[torch.Tensor] = None,
     kv_last_page_lens: Optional[torch.Tensor] = None,
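With the extra return value, every caller of `get_aiter_mla_metadata` must now unpack four tensors, matching the `graph_capture` change above. A usage sketch (assumes a ROCm build of vLLM with a GPU visible):

```python
from vllm.attention.ops.rocm_aiter_mla import get_aiter_mla_metadata

kv_indices, kv_indptr, last_page_lens, qo_indptr = get_aiter_mla_metadata(
    max_batch_size=8,
    block_size=16,
    max_block_per_batch=32,
    device="cuda")
print(qo_indptr.shape)  # torch.Size([9]) -- max_batch_size + 1
```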
@@ -60,9 +63,11 @@ def mla_decode_fwd_impl(
     mla_decode_fwd(q,
                    kv_buffer.view(-1, 1, 1, q.shape[-1]),
                    o,
+                   qo_indptr,
                    kv_indptr,
                    kv_indices,
                    kv_last_page_lens,
+                   max_seqlen_qo,
                    sm_scale=sm_scale,
                    logit_cap=logit_cap)

vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 5 additions & 4 deletions
@@ -123,10 +123,11 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
     fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids,
                              sorted_weight_buf, sorted_expert_ids,
-                             num_valid_ids, topk, w1_scale.view(local_E, -1),
-                             w2_scale.view(local_E, -1),
-                             a1_scale.t().contiguous(), *block_shape,
-                             smooth_scale)
+                             num_valid_ids, topk,
+                             a1_scale.t().contiguous(),
+                             w1_scale.view(local_E, -1),
+                             w2_scale.view(local_E,
+                                           -1), *block_shape, smooth_scale)

     return out_asm
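The semantic part of this fix is the reorder: `a1_scale` now precedes the weight scales in the updated `fmoe_fp8_blockscale_g1u1` signature. Because the extension is invoked purely positionally, an ordering mismatch produces wrong numerics rather than a `TypeError`. One defensive option (a hypothetical shim, not part of AITER or vLLM) is to pin the order behind keyword-only arguments:

```python
# Hypothetical keyword-only shim; fn stands in for fmoe_fp8_blockscale_g1u1.
def fmoe_fp8_blockscale_g1u1_kw(fn, *, out, a1, w1, w2, sorted_token_ids,
                                sorted_weight_buf, sorted_expert_ids,
                                num_valid_ids, topk, a1_scale, w1_scale,
                                w2_scale, block_shape, smooth_scale):
    # Single place that knows the positional order expected by the kernel.
    return fn(out, a1, w1, w2, sorted_token_ids, sorted_weight_buf,
              sorted_expert_ids, num_valid_ids, topk, a1_scale, w1_scale,
              w2_scale, *block_shape, smooth_scale)
```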
