
Commit 85b72cb

Revert "[BugFix][AMD] Compatible patch for latest AITER(05/07/2025)" (#17910)
1 parent 6e5595c commit 85b72cb

File tree

4 files changed: +23 −54 lines changed

vllm/attention/backends/mla/common.py

Lines changed: 6 additions & 6 deletions
@@ -1213,9 +1213,9 @@ def _compute_prefill_context(
         attn_output, attn_softmax_lse = \
             self._flash_attn_varlen_diff_headdims(
-                q,
-                k,
-                v,
+                q=q,
+                k=k,
+                v=v,
                 cu_seqlens_q=prefill_metadata.query_start_loc,
                 cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
                 max_seqlen_q=prefill_metadata.max_query_len,

@@ -1267,9 +1267,9 @@ def _forward_prefill(
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

         output = self._flash_attn_varlen_diff_headdims(
-            q,
-            k,
-            v,
+            q=q,
+            k=k,
+            v=v,
             cu_seqlens_q=prefill_metadata.query_start_loc,
             cu_seqlens_k=prefill_metadata.query_start_loc,
             max_seqlen_q=prefill_metadata.max_prefill_seq_len,
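Note on the change above: the revert binds q/k/v by keyword at both call sites. A minimal sketch of why that is robust, using a hypothetical stand-in (flash_attn_varlen_stub is not vLLM's real backend function): the wrappers behind _flash_attn_varlen_diff_headdims differ in their parameter lists, and keyword binding keeps the calls unambiguous even if a backend inserts or reorders parameters after (q, k, v).

import torch

def flash_attn_varlen_stub(q, k, v, cu_seqlens_q=None, cu_seqlens_k=None,
                           max_seqlen_q=None, max_seqlen_k=None):
    # Stand-in with the same leading parameters; returns zeros of q's shape.
    return torch.zeros_like(q)

# Binding by keyword stays correct regardless of how the trailing
# parameters of a particular backend are ordered.
out = flash_attn_varlen_stub(
    q=torch.randn(8, 16, 128),
    k=torch.randn(8, 16, 128),
    v=torch.randn(8, 16, 128),
    max_seqlen_q=8)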

vllm/attention/backends/rocm_aiter_mla.py

Lines changed: 12 additions & 37 deletions
@@ -53,7 +53,7 @@ def get_state_cls() -> Type["AiterMLAState"]:

 @dataclass
 class AiterMLAMetadata(MLACommonMetadata):
-    # The following 5 tensors are for current version of AITER MLA
+    # The following 4 tensors are for current version of AITER MLA
     block_table_bound: Optional[torch.Tensor] = None
     # The indptr of the paged kv cache, shape: [batch_size + 1]
     paged_kv_indptr: Optional[torch.Tensor] = None

@@ -63,10 +63,6 @@ class AiterMLAMetadata(MLACommonMetadata):
     # the paged kv cache, shape: [batch_size]
     paged_kv_last_page_lens: Optional[torch.Tensor] = None

-    # This is just to make new AITER MLA API work
-    # -- MTP support is not added yet.
-    qo_indptr: Optional[torch.Tensor] = None
-
     @property
     def prefill_metadata(self):
         prefill_metadata = super().prefill_metadata

@@ -78,7 +74,6 @@ def prefill_metadata(self):
             prefill_metadata\
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             prefill_metadata.block_table_bound = self.block_table_bound
-            prefill_metadata.qo_indptr = self.qo_indptr

             # update the cache
             self._cached_prefill_metadata = self.__class__(

@@ -98,7 +93,6 @@ def decode_metadata(self):
             decode_metadata\
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             decode_metadata.block_table_bound = self.block_table_bound
-            decode_metadata.qo_indptr = self.qo_indptr

             # update the cache
             self._cached_decode_metadata = self.__class__(

@@ -142,7 +136,6 @@ def prepare(self):
         self.paged_kv_indptr: list[int] = [0]
         self.paged_kv_last_page_lens: list[int] = []
         self.total_blocks = 0
-        self.qo_indptr: list[int] = [0]

     def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
                        prefix_cache_hit: bool):

@@ -215,7 +208,6 @@ def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
         self.paged_kv_indices.extend(block_table[:block_table_bound])
         self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                     block_table_bound)
-        self.qo_indptr.append(self.qo_indptr[-1] + 1)

         last_page_len = seq_len % self.block_size
         if last_page_len == 0:

@@ -234,8 +226,6 @@ def build(self, seq_lens: list[int], query_lens: list[int],
             self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                         cuda_graph_pad_size)
             self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
-            last_qo_indptr = self.qo_indptr[-1]
-            self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)

         # For current version of AITER MLA
         if len(self.paged_kv_indptr) > 0:

@@ -255,22 +245,16 @@ def build(self, seq_lens: list[int], query_lens: list[int],
                 1,
                 device=device,
                 dtype=torch.int)
-
-            qo_indptr = torch.tensor(self.qo_indptr,
-                                     device=device,
-                                     dtype=torch.int)
         else:
             paged_kv_indices_tensor = None
             paged_kv_indptr_tensor = None
             paged_kv_last_page_lens_tensor = None
             block_table_bound_tensor = None
-            qo_indptr = None

         metadata.paged_kv_indptr = paged_kv_indptr_tensor
         metadata.paged_kv_indices = paged_kv_indices_tensor
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
         metadata.block_table_bound = block_table_bound_tensor
-        metadata.qo_indptr = qo_indptr

         return metadata

@@ -279,25 +263,21 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):

     @contextmanager
     def graph_capture(self, max_batch_size: int):
-        kv_indices, kv_indptr, last_page_lens, qo_indptr = \
-            get_aiter_mla_metadata(
-                max_batch_size=max_batch_size,
-                block_size=self.runner.block_size,
-                max_block_per_batch=\
-                    self.runner.get_max_block_per_batch(),
-                device=self.runner.device)
+        kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
+            max_batch_size=max_batch_size,
+            block_size=self.runner.block_size,
+            max_block_per_batch=self.runner.get_max_block_per_batch(),
+            device=self.runner.device)
         self._paged_kv_indices_tensor = kv_indices
         self._paged_kv_indptr_tensor = kv_indptr
         self._paged_kv_last_page_lens_tensor = last_page_lens
-        self._qo_indptr_tensor = qo_indptr

         with super().graph_capture(max_batch_size):
             yield

         del self._paged_kv_indices_tensor
         del self._paged_kv_indptr_tensor
         del self._paged_kv_last_page_lens_tensor
-        del self._qo_indptr_tensor

     def graph_capture_get_metadata_for_batch(
         self,

@@ -311,12 +291,10 @@ def graph_capture_get_metadata_for_batch(
         paged_kv_indices = self._paged_kv_indices_tensor
         paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
                                                                        batch_size]
-        qo_indptr = self._qo_indptr_tensor[:batch_size + 1]

         metadata.paged_kv_indptr = paged_kv_indptr
         metadata.paged_kv_indices = paged_kv_indices
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
-        metadata.qo_indptr = qo_indptr

         return metadata

@@ -333,7 +311,6 @@ def get_graph_input_buffers(self,
             input_buffers[
                 "paged_kv_last_page_lens"] = attn_metadata.\
                     decode_metadata.paged_kv_last_page_lens
-            input_buffers['qo_indptr'] = attn_metadata.qo_indptr

         return input_buffers

@@ -353,8 +330,6 @@ def prepare_graph_input_buffers(self,
             input_buffers["paged_kv_last_page_lens"].copy_(
                 attn_metadata.decode_metadata.paged_kv_last_page_lens,
                 non_blocking=True)
-            input_buffers["qo_indptr"].copy_(
-                attn_metadata.decode_metadata.qo_indptr, non_blocking=True)


 class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):

@@ -395,9 +370,11 @@ def _flash_attn_varlen_diff_headdims(
             softmax_scale: float, return_softmax_lse: bool,
             **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
         output = self.flash_attn_varlen_func(
-            q,
-            k,
-            v,
+            q=q,
+            k=k,
+            v=v,
+            softmax_scale=softmax_scale,
+            return_lse=return_softmax_lse,
             **kwargs,
         )

@@ -417,7 +394,7 @@ def _forward_decode(
         B = q_nope.shape[0]

         q = torch.cat([q_nope, q_pe], dim=-1)
-        o = torch.empty(B,
+        o = torch.zeros(B,
                         self.num_heads,
                         self.kv_lora_rank,
                         dtype=q.dtype,

@@ -426,8 +403,6 @@ def _forward_decode(
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

         aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
-                             attn_metadata.qo_indptr,
-                             attn_metadata.max_query_len,
                              attn_metadata.paged_kv_indptr,
                              attn_metadata.paged_kv_indices,
                              attn_metadata.paged_kv_last_page_lens)
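For orientation, the three paged-KV tensors that remain once qo_indptr is dropped describe the cache layout as follows. This is an illustrative sketch with hypothetical values, following the shape comments in AiterMLAMetadata and the last_page_len logic in _update_paged_kv_tensors:

import torch

block_size = 16
# Two sequences: 40 tokens (cache blocks 3, 7, 9) and 16 tokens (block 2).
paged_kv_indices = torch.tensor([3, 7, 9, 2], dtype=torch.int)
# indptr[i]:indptr[i+1] slices sequence i's blocks; shape [batch_size + 1].
paged_kv_indptr = torch.tensor([0, 3, 4], dtype=torch.int)
# Tokens used in each sequence's final block; a fully used block stays at
# block_size, matching `if last_page_len == 0: last_page_len = block_size`.
paged_kv_last_page_lens = torch.tensor([40 % block_size, block_size],
                                       dtype=torch.int)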

vllm/attention/ops/rocm_aiter_mla.py

Lines changed: 1 addition & 6 deletions
@@ -20,17 +20,14 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
     paged_kv_last_page_lens = torch.full((max_batch_size, ),
                                          block_size,
                                          dtype=torch.int32)
-    qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
-    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
+    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens


 def aiter_mla_decode_fwd(
     q: torch.Tensor,
     kv_buffer: torch.Tensor,
     o: torch.Tensor,
     sm_scale: float,
-    qo_indptr: torch.Tensor,
-    max_seqlen_qo: int,
     kv_indptr: Optional[torch.Tensor] = None,
     kv_indices: Optional[torch.Tensor] = None,
     kv_last_page_lens: Optional[torch.Tensor] = None,

@@ -63,11 +60,9 @@ def mla_decode_fwd_impl(
     mla_decode_fwd(q,
                    kv_buffer.view(-1, 1, 1, q.shape[-1]),
                    o,
-                   qo_indptr,
                    kv_indptr,
                    kv_indices,
                    kv_last_page_lens,
-                   max_seqlen_qo,
                    sm_scale=sm_scale,
                    logit_cap=logit_cap)
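After this revert, get_aiter_mla_metadata returns three tensors instead of four, so callers unpack a 3-tuple, mirroring the graph_capture change above. A usage sketch with hypothetical sizes (assumes a ROCm build of vLLM):

from vllm.attention.ops.rocm_aiter_mla import get_aiter_mla_metadata

# Sizes here are illustrative, not defaults.
kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
    max_batch_size=256,
    block_size=16,
    max_block_per_batch=128,
    device="cuda")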

vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 4 additions & 5 deletions
@@ -123,11 +123,10 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(

     fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids,
                              sorted_weight_buf, sorted_expert_ids,
-                             num_valid_ids, topk,
-                             a1_scale.t().contiguous(),
-                             w1_scale.view(local_E, -1),
-                             w2_scale.view(local_E,
-                                           -1), *block_shape, smooth_scale)
+                             num_valid_ids, topk, w1_scale.view(local_E, -1),
+                             w2_scale.view(local_E, -1),
+                             a1_scale.t().contiguous(), *block_shape,
+                             smooth_scale)

     return out_asm
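The MoE change is a pure positional reorder: a1_scale moves from before the weight scales to after them, restoring the argument order expected by the pre-patch fmoe_fp8_blockscale_g1u1 kernel. A schematic sketch with a hypothetical stub (fmoe_stub is not the AITER kernel; shapes are made up):

import torch

def fmoe_stub(*scales):
    # Show which tensor lands in each positional slot.
    for i, s in enumerate(scales):
        print(f"arg {i}: shape {tuple(s.shape)}")

local_E = 2
w1_scale = torch.ones(local_E, 4, 2)   # hypothetical block scales
w2_scale = torch.ones(local_E, 4, 2)
a1_scale = torch.ones(128, local_E)

# Post-revert order: weight scales first, then the transposed activation scale.
fmoe_stub(w1_scale.view(local_E, -1),
          w2_scale.view(local_E, -1),
          a1_scale.t().contiguous())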
