
Commit 6f149db

zzzzwwjj authored and wangxiyuan committed
[fix] fix torchair & ci bug
Signed-off-by: zzzzwwjj <[email protected]>
1 parent 95f684f commit 6f149db

File tree: 5 files changed, +77 -89 lines changed

vllm_ascend/attention/sfa_v1.py

Lines changed: 31 additions & 53 deletions
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
+from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
+                    TypeVar)
 
 import torch
 import torch_npu
@@ -12,6 +13,7 @@
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -77,9 +79,9 @@ class ChunkedContextMetadata:
     block_table: torch.Tensor
     max_query_len: int
     max_seq_lens: int
+    sin: torch.Tensor
+    cos: torch.Tensor
     chunked_context: Optional[ChunkedContextMetadata] = None
-    sin: torch.Tensor = None
-    cos: torch.Tensor = None
 
 
 @dataclass
@@ -91,10 +93,10 @@ class AscendSFADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
-    actual_seq_lengths_q: Optional[torch.Tensor] = None
+    actual_seq_lengths_q: torch.Tensor
+    sin: torch.Tensor
+    cos: torch.Tensor
     attn_mask: Optional[torch.Tensor] = None
-    sin: torch.Tensor = None
-    cos: torch.Tensor = None
 
 
 @dataclass
@@ -163,6 +165,9 @@ def split_metadata_for_multistream(
 
 
 class AscendSFAMetadataBuilder:
+    # Does this backend/builder support ACL Graphs for attention (default: no).
+    aclgraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.NEVER
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
@@ -292,11 +297,10 @@ def build(
         device = self.device
 
         block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
-                                                             num_actual_tokens].to(
-                                                                 device,
-                                                                 non_blocking=
-                                                                 True)
+        slot_mapping = common_attn_metadata.slot_mapping[:
+                                                         num_actual_tokens].to(
+                                                             device,
+                                                             non_blocking=True)
         input_positions = common_attn_metadata.positions[:
                                                          num_actual_tokens].long(
                                                          )
@@ -686,8 +690,7 @@ def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata,
             topk_indices = self.indexer_select(hidden_states_decode,
                                                decode_q_c,
                                                attn_metadata=attn_metadata,
-                                               kv_cache=kv_cache,
-                                               is_prefill=False)
+                                               kv_cache=kv_cache)
 
             query_states = (decode_q_nope, decode_q_pe)
             key_states = (decode_k_nope, decode_k_rope)
@@ -775,8 +778,7 @@ def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata,
             topk_indices = self.indexer_select(x=hidden_states_prefill,
                                                qr=prefill_qr,
                                                kv_cache=kv_cache,
-                                               attn_metadata=attn_metadata,
-                                               is_prefill=True)
+                                               attn_metadata=attn_metadata)
             query_states = (prefill_q_nope, prefill_q_pe)
             key_states = (prefill_k_nope, prefill_k_pe)
             prefill_preprocess_res = PrefillSFAPreprocessResult(
@@ -826,45 +828,27 @@ def forward(
                 query_states=decode_preprocess_res.query_states,
                 key_states=decode_preprocess_res.key_states,
                 attn_metadata=attn_metadata,
-                attention_mask=None,
-                kv_cache=kv_cache,
-                topk_indices=decode_preprocess_res.topk_indices,
-                is_prefill=False,
-                bsz=decode_preprocess_res.bsz)
+                topk_indices=decode_preprocess_res.topk_indices)
             o_proj_input[:num_decode_tokens] = decode_attn_output
 
         if prefill_preprocess_res is not None:
             prefill_attn_output = self.apply_attention_fusion(
                 query_states=prefill_preprocess_res.query_states,
                 key_states=prefill_preprocess_res.key_states,
                 attn_metadata=attn_metadata,
-                attention_mask=None,
-                kv_cache=kv_cache,
-                topk_indices=prefill_preprocess_res.topk_indices,
-                is_prefill=True,
-                bsz=None)
+                topk_indices=prefill_preprocess_res.topk_indices)
             o_proj_input[num_decode_tokens:] = prefill_attn_output
 
         output[...] = self.mla_epilog(o_proj_input, absorb=True)
         return output
 
-    def apply_attention_fusion(
-            self,
-            query_states,
-            key_states,
-            topk_indices,
-            attn_metadata: M,
-            attention_mask: Optional[torch.Tensor] = None,
-            # actual_seq_qlen: torch.Tensor = None,
-            # actual_seq_lengths_kv: torch.Tensor = None,
-            kv_cache: Tuple[torch.Tensor] = None,
-            is_prefill: bool = True,
-            bsz: int = None):
+    def apply_attention_fusion(self, query_states, key_states, topk_indices,
+                               attn_metadata: M):
         # repeat k/v heads if n_kv_heads < n_heads
         q_nope, q_pe = query_states
         k_nope, k_rope = key_states
 
-        if is_prefill:
+        if attn_metadata.prefill is not None:
 
             prefill_metadata = attn_metadata.prefill
 
@@ -885,7 +869,7 @@ def apply_attention_fusion(
                 sparse_mode=3,
             )
 
-        else:
+        elif attn_metadata.decode is not None:
             decode_metadata = attn_metadata.decode
 
             slc_fa_fusion = torch.ops.custom.npu_selected_flash_attention(
@@ -937,14 +921,19 @@ def indexer_select(
         qr: torch.Tensor,
         kv_cache: Tuple[torch.Tensor],
         attn_metadata: M,
-        is_prefill: bool = True,
     ):
-        if is_prefill:
+        if attn_metadata.prefill is not None:
             cos = attn_metadata.prefill.cos
             sin = attn_metadata.prefill.sin
-        else:
+            actual_seq_lengths_query = attn_metadata.prefill.query_lens
+            actual_seq_lengths_key = attn_metadata.prefill.seq_lens
+            block_table = attn_metadata.prefill.block_table
+        elif attn_metadata.decode is not None:
             cos = attn_metadata.decode.cos
             sin = attn_metadata.decode.sin
+            actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q
+            actual_seq_lengths_key = attn_metadata.decode.seq_lens
+            block_table = attn_metadata.decode.block_table
 
         cos_q, sin_q = cos, sin
         cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
@@ -982,17 +971,6 @@
                                    k.shape[-1]))  # b, s, n, d
 
         weights = self.weights_proj(x)
-        actual_seq_lengths_query = None
-        actual_seq_lengths_key = None
-        block_table = None
-        if is_prefill:
-            actual_seq_lengths_query = attn_metadata.prefill.query_lens
-            actual_seq_lengths_key = attn_metadata.prefill.seq_lens
-            block_table = attn_metadata.prefill.block_table
-        else:
-            actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q
-            actual_seq_lengths_key = attn_metadata.decode.seq_lens
-            block_table = attn_metadata.decode.block_table
 
         topk_indices = torch.ops.custom.npu_lightning_indexer(
             query=q,
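
The main change in this file drops the `is_prefill`/`bsz`/`kv_cache` plumbing from `apply_attention_fusion` and `indexer_select` and branches on `attn_metadata.prefill` / `attn_metadata.decode` instead. A minimal, self-contained sketch of that dispatch pattern (the dataclasses below are simplified stand-ins for illustration, not the real vllm_ascend metadata types):

from dataclasses import dataclass
from typing import Optional


@dataclass
class PrefillMeta:
    query_lens: list
    seq_lens: list
    block_table: object = None


@dataclass
class DecodeMeta:
    actual_seq_lengths_q: list
    seq_lens: list
    block_table: object = None


@dataclass
class AttnMeta:
    prefill: Optional[PrefillMeta] = None
    decode: Optional[DecodeMeta] = None


def select_lengths(attn_metadata: AttnMeta):
    # Mirrors the rewritten indexer_select: derive the per-branch lengths and
    # the block table from whichever sub-metadata is populated, instead of
    # threading a separate is_prefill flag through every call site.
    if attn_metadata.prefill is not None:
        m = attn_metadata.prefill
        return m.query_lens, m.seq_lens, m.block_table
    elif attn_metadata.decode is not None:
        m = attn_metadata.decode
        return m.actual_seq_lengths_q, m.seq_lens, m.block_table
    raise ValueError("metadata carries neither prefill nor decode info")


print(select_lengths(AttnMeta(decode=DecodeMeta([1], [8]))))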

vllm_ascend/patch/worker/patch_common/patch_attention_layer.py

Lines changed: 19 additions & 8 deletions
@@ -16,6 +16,8 @@
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import current_platform
 
+from vllm_ascend.utils import vllm_version_is
+
 
 class AscendAttention(Attention, nn.Module, AttentionLayerBase):
     """Attention layer.
@@ -133,14 +135,23 @@ def __init__(
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
         if attn_backend is None:
-            self.attn_backend = get_attn_backend(head_size,
-                                                 dtype,
-                                                 kv_cache_dtype,
-                                                 block_size,
-                                                 is_attention_free,
-                                                 use_mla=use_mla,
-                                                 use_sfa=use_sfa,
-                                                 has_sink=self.has_sink)
+            if vllm_version_is("0.10.2"):
+                self.attn_backend = get_attn_backend(head_size,
+                                                     dtype,
+                                                     kv_cache_dtype,
+                                                     block_size,
+                                                     is_attention_free,
+                                                     use_mla=use_mla,
+                                                     use_sfa=use_sfa,
+                                                     has_sink=self.has_sink)
+            else:
+                self.attn_backend = get_attn_backend(head_size,
+                                                     dtype,
+                                                     kv_cache_dtype,
+                                                     block_size,
+                                                     use_mla=use_mla,
+                                                     use_sfa=use_sfa,
+                                                     has_sink=self.has_sink)
         else:
             self.attn_backend = attn_backend
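
The patched `AscendAttention.__init__` now picks the `get_attn_backend` call form by installed vLLM version: 0.10.2 still takes the positional `is_attention_free` argument, later versions drop it. A hedged sketch of the same gating with stub functions (the stubs are placeholders for illustration, not the real vLLM API):

# Hypothetical stand-ins for vllm_version_is / get_attn_backend; only the
# dispatch shape matches the patch above.
INSTALLED_VLLM = "0.11.0"  # assumed value for illustration


def vllm_version_is(version: str) -> bool:
    return INSTALLED_VLLM == version


def get_attn_backend_legacy(head_size, dtype, kv_cache_dtype, block_size,
                            is_attention_free, *, use_mla, use_sfa, has_sink):
    return "legacy-backend"


def get_attn_backend_current(head_size, dtype, kv_cache_dtype, block_size, *,
                             use_mla, use_sfa, has_sink):
    return "current-backend"


def select_attn_backend(head_size, dtype, kv_cache_dtype, block_size,
                        is_attention_free, use_mla, use_sfa, has_sink):
    # Same branching as the patched __init__: keep the old positional
    # argument only when running against vLLM 0.10.2.
    if vllm_version_is("0.10.2"):
        return get_attn_backend_legacy(head_size, dtype, kv_cache_dtype,
                                       block_size, is_attention_free,
                                       use_mla=use_mla, use_sfa=use_sfa,
                                       has_sink=has_sink)
    return get_attn_backend_current(head_size, dtype, kv_cache_dtype,
                                    block_size, use_mla=use_mla,
                                    use_sfa=use_sfa, has_sink=has_sink)


print(select_attn_backend(128, "bf16", "auto", 128, False, True, True, False))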

vllm_ascend/torchair/torchair_model_runner.py

Lines changed: 2 additions & 3 deletions
@@ -43,7 +43,6 @@
                               is_310p, get_ascend_soc_version,
                               AscendSocVersion)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
-import vllm.envs as envs_vllm
 
 
 class NPUTorchairModelRunner(NPUModelRunner):
@@ -378,7 +377,7 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int):
             self.torchair_compiled_model = torch.compile(
                 self.model,
                 dynamic=not self.ascend_config.use_sfa,
-                fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                fullgraph=True,
                 backend=npu_backend)
             return self.torchair_compiled_model
         else:
@@ -401,7 +400,7 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int):
                     batch_size] = torchair.inference.cache_compile(
                         self.model.__dict__[forward_proxy_name],
                         dynamic=not self.ascend_config.use_sfa,
-                        fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                        fullgraph=True,
                         cache_dir=TORCHAIR_CACHE_DIR,
                         config=config,
                         ge_cache=False)
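
Both compile paths now pass `fullgraph=True` directly instead of reading the removed `envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE` flag. A minimal illustration of what `fullgraph=True` means for `torch.compile` (default backend here purely for the example; the real code hands in a torchair NPU backend or goes through `torchair.inference.cache_compile`):

import torch


class TinyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


# fullgraph=True asks Dynamo to capture the whole forward as a single graph
# and to raise on a graph break instead of silently falling back to eager.
compiled = torch.compile(TinyModel(), dynamic=False, fullgraph=True)
print(compiled(torch.randn(2, 8)).shape)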

vllm_ascend/torchair/torchair_sfa.py

Lines changed: 15 additions & 17 deletions
@@ -81,9 +81,9 @@ class TorchairChunkedContextMetadata:
     block_table: torch.Tensor
     max_query_len: int
     max_seq_lens: int
+    sin: torch.Tensor
+    cos: torch.Tensor
     chunked_context: Optional[TorchairChunkedContextMetadata] = None
-    sin: torch.Tensor = None
-    cos: torch.Tensor = None
 
 
 @dataclass
@@ -95,10 +95,10 @@ class AscendSFATorchairDecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
-    actual_seq_lengths_q: Optional[torch.Tensor] = None
+    actual_seq_lengths_q: torch.Tensor
+    sin: torch.Tensor
+    cos: torch.Tensor
     attn_mask: Optional[torch.Tensor] = None
-    sin: torch.Tensor = None
-    cos: torch.Tensor = None
 
 
 @dataclass
@@ -410,11 +410,10 @@ def build(
         device = self.device
 
         block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
-                                                             num_actual_tokens].to(
-                                                                 device,
-                                                                 non_blocking=
-                                                                 True)
+        slot_mapping = common_attn_metadata.slot_mapping[:
+                                                         num_actual_tokens].to(
+                                                             device,
+                                                             non_blocking=True)
         input_positions = common_attn_metadata.positions[:
                                                          num_actual_tokens].long(
                                                          )
@@ -984,7 +983,7 @@ def forward(
 
         has_prefill = attn_metadata.is_prefill
        has_decode = attn_metadata.is_decode
-        if has_prefill:
+        if attn_metadata.prefill is not None:
             # num_actual_tokens = attn_metadata.num_actual_tokens
             assert attn_metadata.num_decodes is not None and \
                 attn_metadata.num_prefills is not None and \
@@ -1107,7 +1106,7 @@ def forward(
             output[...] = self.o_proj(attn_output, is_force_scatter=True)
             return output
 
-        if has_decode:
+        elif attn_metadata.decode is not None:
             if envs_ascend.VLLM_ASCEND_ENABLE_MLAPO:
                 prep_res = self._sfa_decode_preprocess(hidden_states, kv_cache,
                                                        attn_metadata,
@@ -1227,10 +1226,10 @@ def indexer_select(
         attn_metadata: M,
         is_prefill: bool = True,
     ):
-        if is_prefill:
+        if attn_metadata.prefill is not None:
             cos = attn_metadata.prefill.cos
             sin = attn_metadata.prefill.sin
-        else:
+        elif attn_metadata.decode is not None:
             cos = attn_metadata.decode.cos
             sin = attn_metadata.decode.sin
 
@@ -1281,14 +1280,13 @@ def indexer_select(
         actual_seq_lengths_query = None
         actual_seq_lengths_key = None
        block_table = None
-        if is_prefill:
+        if attn_metadata.prefill is not None:
             actual_seq_lengths_query = attn_metadata.prefill.query_lens
             actual_seq_lengths_key = attn_metadata.prefill.seq_lens
 
             block_table = attn_metadata.prefill.block_table
-        else:
+        elif attn_metadata.decode is not None:
             actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q
-            # actual_seq_lengths_query = self.actual_seq_length
             actual_seq_lengths_key = attn_metadata.decode.seq_lens.to(
                 torch.int32)
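
As in sfa_v1.py, `sin`/`cos` lose their `= None` defaults here, which is why they move above the defaulted `chunked_context`/`attn_mask` fields: a dataclass cannot declare a non-default field after one with a default. A small sketch of that rule (generic names, not the vllm_ascend classes):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class DecodeMetaOk:
    # Required fields first ...
    actual_seq_lengths_q: torch.Tensor
    sin: torch.Tensor
    cos: torch.Tensor
    # ... defaulted fields last.
    attn_mask: Optional[torch.Tensor] = None


# Reversing the order fails at class-definition time: a field without a
# default may not follow a field that has one.
try:

    @dataclass
    class DecodeMetaBad:
        attn_mask: Optional[torch.Tensor] = None
        sin: torch.Tensor
        cos: torch.Tensor

except TypeError as err:
    print(err)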

vllm_ascend/worker/model_runner_v1.py

Lines changed: 10 additions & 8 deletions
@@ -2511,7 +2511,7 @@ def profile_run(self) -> None:
         # MC2 will consume additional NPU memory.
         # Therefore, we need to run the MC2 path once here to complete its initialization,
         # allowing vLLM to correctly estimate the maximum memory required.
-        if self._select_moe_comm_method(
+        if not self.ascend_config.torchair_graph_config.enabled and self._select_moe_comm_method(
                 self.mc2_tokens_capacity,
                 with_prefill=True) == MoECommType.MC2:
             self._dummy_run(self.mc2_tokens_capacity, with_prefill=True)
@@ -2684,10 +2684,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.kv_cache_config = kv_cache_config
         self.initialize_attn_backend(kv_cache_config)
         self.use_hybrid_blocks = (len(self.attn_groups) > 1)
-        # NOTE: Currently, we determine whether we need `num_accepted_tokens` through `GDNAttentionMetadataBuilder`.
+        # NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`.
         self.need_accepted_tokens = any([
-            isinstance(attn_group[0].metadata_builder,
-                       GDNAttentionMetadataBuilder)
+            isinstance(attn_group[0].kv_cache_spec, MambaSpec)
             for attn_group in self.attn_groups
         ])
         self.may_reinitialize_input_batch(kv_cache_config)
@@ -2721,10 +2720,13 @@ def initialize_kv_cache_tensors_deepseek_sfa(
             kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
 
         kv_caches: Dict[str, torch.Tensor] = {}
-        for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
-        ):
-            attn_backend = kv_cache_group.backend
-            for layer_name in kv_cache_group.layer_names:
+        for group in self._kv_cache_spec_attn_group_iterator_dispatcher():
+            if vllm_version_is("0.10.2"):
+                kv_cache_spec, group = group
+            else:
+                kv_cache_spec = group.kv_cache_spec
+            attn_backend = group.backend
+            for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
                 tensor_size = kv_cache_sizes[layer_name]
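
The kv-cache initialization loop now goes through `_kv_cache_spec_attn_group_iterator_dispatcher()`, which yields `(spec, group)` tuples on vLLM 0.10.2 and group objects carrying their own `kv_cache_spec` on newer versions. A hedged sketch of the normalization the loop performs (all names other than those in the diff are hypothetical):

from dataclasses import dataclass, field


@dataclass
class FakeGroup:
    # Hypothetical stand-in for an attention group on newer vLLM versions,
    # where the spec hangs off the group itself.
    kv_cache_spec: str
    backend: str = "ascend"
    layer_names: list = field(default_factory=lambda: ["layer.0"])


def iterate_groups(groups, on_0_10_2: bool):
    # Stand-in for the dispatcher: older vLLM yields (spec, group) pairs,
    # newer vLLM yields the group directly.
    for g in groups:
        yield (g.kv_cache_spec, g) if on_0_10_2 else g


def collect_specs(groups, on_0_10_2: bool):
    specs = []
    for group in iterate_groups(groups, on_0_10_2):
        if on_0_10_2:
            kv_cache_spec, group = group  # unpack the legacy tuple
        else:
            kv_cache_spec = group.kv_cache_spec
        specs.append((kv_cache_spec, group.backend, group.layer_names))
    return specs


groups = [FakeGroup("full_attention"), FakeGroup("mamba")]
print(collect_specs(groups, True) == collect_specs(groups, False))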
