|
199 | 199 | from dataclasses import dataclass |
200 | 200 | from itertools import accumulate |
201 | 201 | from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, |
202 | | - Type) |
| 202 | + Type, TypeVar) |
203 | 203 |
|
204 | 204 | import torch |
205 | 205 | from compressed_tensors.quantization import QuantizationStrategy |
|
209 | 209 | from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, |
210 | 210 | AttentionMetadata, |
211 | 211 | AttentionMetadataBuilder, |
212 | | - AttentionState, MLAAttentionImpl, |
213 | | - T) |
| 212 | + AttentionState, MLAAttentionImpl) |
214 | 213 | from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, |
215 | 214 | compute_slot_mapping_start_idx, |
216 | 215 | get_flash_attn_version, |
@@ -723,6 +722,9 @@ def advance_step(self, |
723 | 722 | block_tables=self.block_tables) |
724 | 723 |
|
725 | 724 |
|
| 725 | +T = TypeVar("T", bound=MLACommonMetadata) |
| 726 | + |
| 727 | + |
726 | 728 | class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]): |
727 | 729 | """ |
728 | 730 | NOTE: Please read the comment at the top of the file before trying to |
@@ -1268,12 +1270,15 @@ def _compute_prefill_context( |
1268 | 1270 | assert prefill_metadata.context_chunk_cu_seq_lens is not None |
1269 | 1271 | assert prefill_metadata.context_chunk_starts is not None |
1270 | 1272 | assert prefill_metadata.context_chunk_max_seq_lens is not None |
1271 | | - # assert prefill_metadata.block_tables is not None |
1272 | 1273 | assert prefill_metadata.context_lens_tensor is not None |
1273 | 1274 |
|
1274 | 1275 | output = None |
1275 | 1276 | iters = len(prefill_metadata.context_chunk_seq_tot) |
1276 | | - assert hasattr(attn_metadata, "chunked_prefill_workspace") |
| 1277 | + |
| 1278 | + # Fetch from attn_metadata directly, since it late bound by |
| 1279 | + # MLAAttentionState, grabbing it directly `attn_metadata` can avoid |
| 1280 | + # any weirdness around prefill_metadata caching |
| 1281 | + assert attn_metadata.chunked_prefill_workspace is not None |
1277 | 1282 | workspace = attn_metadata.chunked_prefill_workspace |
1278 | 1283 |
|
1279 | 1284 | for i in range(iters): |
@@ -1345,9 +1350,8 @@ def _forward_prefill( |
1345 | 1350 | kv_c_normed: torch.Tensor, |
1346 | 1351 | k_pe: torch.Tensor, |
1347 | 1352 | kv_c_and_k_pe_cache: torch.Tensor, |
1348 | | - attn_metadata: T, |
| 1353 | + attn_metadata: MLACommonMetadata, |
1349 | 1354 | ) -> torch.Tensor: |
1350 | | - assert isinstance(attn_metadata, MLACommonMetadata) |
1351 | 1355 |
|
1352 | 1356 | prefill_metadata = attn_metadata.prefill_metadata |
1353 | 1357 | assert prefill_metadata is not None |
|
0 commit comments