Skip to content

Commit 28464b5

Browse files
mypy pass
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 3a0ae51 commit 28464b5

File tree

2 files changed

+18
-19
lines changed

2 files changed

+18
-19
lines changed

vllm/attention/backends/mla/common.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@
199199
from dataclasses import dataclass
200200
from itertools import accumulate
201201
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
202-
Type)
202+
Type, TypeVar)
203203

204204
import torch
205205
from compressed_tensors.quantization import QuantizationStrategy
@@ -209,8 +209,7 @@
209209
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
210210
AttentionMetadata,
211211
AttentionMetadataBuilder,
212-
AttentionState, MLAAttentionImpl,
213-
T)
212+
AttentionState, MLAAttentionImpl)
214213
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
215214
compute_slot_mapping_start_idx,
216215
get_flash_attn_version,
@@ -723,6 +722,9 @@ def advance_step(self,
723722
block_tables=self.block_tables)
724723

725724

725+
T = TypeVar("T", bound=MLACommonMetadata)
726+
727+
726728
class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]):
727729
"""
728730
NOTE: Please read the comment at the top of the file before trying to
@@ -1268,12 +1270,15 @@ def _compute_prefill_context(
12681270
assert prefill_metadata.context_chunk_cu_seq_lens is not None
12691271
assert prefill_metadata.context_chunk_starts is not None
12701272
assert prefill_metadata.context_chunk_max_seq_lens is not None
1271-
# assert prefill_metadata.block_tables is not None
12721273
assert prefill_metadata.context_lens_tensor is not None
12731274

12741275
output = None
12751276
iters = len(prefill_metadata.context_chunk_seq_tot)
1276-
assert hasattr(attn_metadata, "chunked_prefill_workspace")
1277+
1278+
# Fetch from attn_metadata directly, since it is late-bound by
1279+
# MLAAttentionState; grabbing it directly from `attn_metadata` can avoid
1280+
# any weirdness around prefill_metadata caching
1281+
assert attn_metadata.chunked_prefill_workspace is not None
12771282
workspace = attn_metadata.chunked_prefill_workspace
12781283

12791284
for i in range(iters):
@@ -1345,9 +1350,8 @@ def _forward_prefill(
13451350
kv_c_normed: torch.Tensor,
13461351
k_pe: torch.Tensor,
13471352
kv_c_and_k_pe_cache: torch.Tensor,
1348-
attn_metadata: T,
1353+
attn_metadata: MLACommonMetadata,
13491354
) -> torch.Tensor:
1350-
assert isinstance(attn_metadata, MLACommonMetadata)
13511355

13521356
prefill_metadata = attn_metadata.prefill_metadata
13531357
assert prefill_metadata is not None

vllm/attention/backends/triton_mla.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,14 @@ def _forward_decode(
8080
dtype=q.dtype,
8181
device=q.device)
8282

83+
num_kv_splits = 4 # TODO: heuristic
84+
8385
# TODO(lucas) Allocate ahead of time
8486
attn_logits = torch.empty(
8587
(
8688
B,
8789
self.num_heads,
88-
4, #attn_metadata.num_kv_splits,
90+
num_kv_splits,
8991
# NOTE(lucas) idk why the +1 is here but sglang has it so we
9092
# just mirror that
9193
self.kv_lora_rank + 1,
@@ -100,16 +102,9 @@ def _forward_decode(
100102
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
101103

102104
# Run MQA
103-
decode_attention_fwd(
104-
q,
105-
kv_c_and_k_pe_cache,
106-
kv_c_cache,
107-
o,
108-
decode_meta.block_tables,
109-
decode_meta.seq_lens_tensor,
110-
attn_logits,
111-
4,
112-
self.scale, #attn_metadata.num_kv_splits
113-
PAGE_SIZE)
105+
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
106+
decode_meta.block_tables,
107+
decode_meta.seq_lens_tensor, attn_logits,
108+
num_kv_splits, self.scale, PAGE_SIZE)
114109

115110
return self._v_up_proj_and_o_proj(o)

0 commit comments

Comments
 (0)