
Commit 8e9abf7

linfeng-yuan authored and wangxiyuan committed
fix ci
Signed-off-by: linfeng-yuan <[email protected]>
1 parent 6f149db commit 8e9abf7

File tree

5 files changed: +17 -19 lines changed


vllm_ascend/attention/sfa_v1.py

Lines changed: 2 additions & 2 deletions
@@ -796,7 +796,7 @@ def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata,
     def forward(
         self,
         hidden_states: torch.Tensor,  # query in unified attn
-        kv_cache: Tuple[torch.Tensor],
+        kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
         attn_metadata: M,
         need_gather_q_kv: bool = False,
         output: Optional[torch.Tensor] = None,
@@ -919,7 +919,7 @@ def indexer_select(
         self,
         x: torch.Tensor,
         qr: torch.Tensor,
-        kv_cache: Tuple[torch.Tensor],
+        kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
         attn_metadata: M,
     ):
         if attn_metadata.prefill is not None:
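Note on the annotation change above: to a type checker, Tuple[torch.Tensor] denotes a tuple of exactly one tensor, so unpacking a 3-tuple cache fails mypy under the old hint. A minimal sketch of the distinction (the tensor names are hypothetical; the diff does not say what each cache tensor holds):

    from typing import Tuple

    import torch

    def unpack_cache(
            kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        # Well-typed: the annotation promises exactly three tensors.
        # Under Tuple[torch.Tensor] (a 1-tuple), mypy rejects this unpack.
        cache_a, cache_b, cache_c = kv_cache  # hypothetical names
        return cache_a.shape, cache_b.shape, cache_c.shape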

vllm_ascend/patch/worker/patch_common/patch_attention_selector.py

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ def _cached_get_attn_backend(
         return resolve_obj_by_qualname(attention_cls)
 else:

-    def get_attn_backend(
+    def get_attn_backend(  # type: ignore[misc]
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: Optional[str],
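The added # type: ignore[misc] suppresses a mypy error on this redefinition; a common trigger for the [misc] code in this pattern is "All conditional function variants must have identical signatures", raised when the two branches of a gate define the same name with different signatures. A hedged, self-contained illustration (the condition and names are invented for the sketch):

    import os

    if os.environ.get("USE_SHORT_SIGNATURE"):
        def resolve(name: str) -> str:
            return name
    else:
        # The extra parameter makes the two variants differ, tripping
        # mypy's conditional-variant check; the ignore silences it.
        def resolve(name: str, strict: bool = False) -> str:  # type: ignore[misc]
            return name.strip() if strict else name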

vllm_ascend/torchair/models/torchair_deepseek_v2.py

Lines changed: 1 addition & 1 deletion
@@ -890,7 +890,7 @@ def __init__(
                 attn_cls = TorchairDeepseekV2SFAAttention
                 self.use_sfa = True
             else:
-                attn_cls = TorchairDeepseekV2MLAAttention
+                attn_cls = TorchairDeepseekV2MLAAttention  # type: ignore[assignment]
         else:
             attn_cls = DeepseekV2Attention
         self.self_attn = attn_cls(
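The ignore here targets mypy's [assignment] code: once attn_cls is inferred from the first branch as type[TorchairDeepseekV2SFAAttention], assigning an unrelated class in the else branch is an incompatible assignment. A small sketch of the same error shape (classes invented for illustration):

    class FastAttention:
        def __init__(self, dim: int) -> None:
            self.dim = dim

    class FallbackAttention:
        def __init__(self, dim: int, window: int) -> None:
            self.dim, self.window = dim, window

    def pick(use_fast: bool):
        cls = FastAttention
        if not use_fast:
            # mypy: expression has type "type[FallbackAttention]",
            # variable has type "type[FastAttention]"  [assignment]
            cls = FallbackAttention  # type: ignore[assignment]
        return cls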

vllm_ascend/torchair/torchair_sfa.py

Lines changed: 4 additions & 13 deletions
@@ -971,7 +971,7 @@ def _sfa_decode_preprocess(self, hidden_states, kv_cache, attn_metadata,
     def forward(
         self,
         hidden_states: torch.Tensor,  # query in unified attn
-        kv_cache: Tuple[torch.Tensor],
+        kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
         attn_metadata: M,
         need_gather_q_kv: bool = False,
         output: Optional[torch.Tensor] = None,
@@ -981,21 +981,12 @@ def forward(
             # Profiling run.
             return output

-        has_prefill = attn_metadata.is_prefill
-        has_decode = attn_metadata.is_decode
+
         if attn_metadata.prefill is not None:
-            # num_actual_tokens = attn_metadata.num_actual_tokens
             assert attn_metadata.num_decodes is not None and \
                 attn_metadata.num_prefills is not None and \
                 attn_metadata.num_decode_tokens is not None
-            # num_decode_tokens = attn_metadata.num_decode_tokens
-            # Inputs and outputs may be padded for CUDA graphs
-            # has_decode = attn_metadata.num_decodes > 0
-            has_prefill = attn_metadata.num_prefills > 0
-            # num_decode_tokens = attn_metadata.num_decode_tokens
-            # num_actual_tokens = attn_metadata.num_actual_tokens
-
-            # output_padded = output
+
             bsz = 1

             hidden_states_prefill = hidden_states
@@ -1222,7 +1213,7 @@ def indexer_select(
         self,
         x: torch.Tensor,
         qr: torch.Tensor,
-        kv_cache: Tuple[torch.Tensor],
+        kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
         attn_metadata: M,
         is_prefill: bool = True,
     ):
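Besides the annotation fix, the middle hunk drops the stale has_prefill/has_decode locals and a block of commented-out bookkeeping; the live branch condition is attn_metadata.prefill is not None. A stripped-down sketch of the resulting control flow (attribute layout assumed from the diff, not the full class):

    def dispatch(attn_metadata) -> str:
        # Branch directly on the metadata's prefill section instead of
        # precomputing boolean flags that nothing else reads.
        if attn_metadata.prefill is not None:
            return "prefill path"
        return "decode path"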

vllm_ascend/worker/worker_v1.py

Lines changed: 9 additions & 2 deletions
@@ -43,7 +43,7 @@
 from vllm.v1.worker.worker_base import WorkerBase

 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import init_ascend_config
+from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
@@ -88,7 +88,14 @@ def __init__(
         # init ascend config and soc version
         init_ascend_config(vllm_config)
         init_ascend_soc_version()
-        import custom_ops  # noqa
+        if get_ascend_config().use_sfa:
+            # Direct import instead of using try_register_lib to ensure proper error handling when
+            # custom_ops is necessary but not available (e.g., in DeepSeek v3.2 deployments)
+            import custom_ops  # type: ignore[import-untyped] # noqa
+            logger.info(
+                "custom_ops module loaded successfully. Custom operators like "
+                "torch.ops.custom.npu_selected_flash_attention are now available."
+            )

         super().__init__(vllm_config=vllm_config,
                          local_rank=local_rank,
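The worker change gates the custom_ops import on the SFA config flag and, per the in-diff comment, deliberately lets an ImportError propagate so a missing kernel package fails loudly when it is actually required. A minimal sketch of this guarded-import pattern (the function name is hypothetical):

    def load_custom_ops_if_needed(use_sfa: bool) -> None:
        if not use_sfa:
            return
        # No try/except on purpose: if the optional package is required
        # but absent, the ImportError should surface to the user.
        import custom_ops  # noqa: F401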
