Skip to content

Commit a99579e

Browse files
linfeng-yuan authored and wangxiyuan committed
adapt to new custom_ops interface
Signed-off-by: linfeng-yuan <[email protected]>
1 parent 8e9abf7 commit a99579e

File tree

3 files changed

+15
-15
lines changed

3 files changed

+15
-15
lines changed

vllm_ascend/attention/sfa_v1.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -852,13 +852,13 @@ def apply_attention_fusion(self, query_states, key_states, topk_indices,
852852

853853
prefill_metadata = attn_metadata.prefill
854854

855-
slc_fa_fusion = torch.ops.custom.npu_selected_flash_attention(
855+
slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention(
856856
query=q_nope,
857857
key=k_nope,
858858
value=k_nope,
859-
selected_indices=topk_indices,
859+
sparse_indices=topk_indices,
860860
scale_value=self.scale,
861-
selected_block_size=1,
861+
sparse_block_size=1,
862862
block_table=prefill_metadata.block_table,
863863
actual_seq_lengths_query=prefill_metadata.query_lens,
864864
actual_seq_lengths_kv=prefill_metadata.seq_lens,
@@ -872,13 +872,13 @@ def apply_attention_fusion(self, query_states, key_states, topk_indices,
872872
elif attn_metadata.decode is not None:
873873
decode_metadata = attn_metadata.decode
874874

875-
slc_fa_fusion = torch.ops.custom.npu_selected_flash_attention(
875+
slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention(
876876
query=q_nope,
877877
key=k_nope,
878878
value=k_nope,
879-
selected_indices=topk_indices,
879+
sparse_indices=topk_indices,
880880
scale_value=self.scale,
881-
selected_block_size=1,
881+
sparse_block_size=1,
882882
block_table=attn_metadata.decode.block_table,
883883
actual_seq_lengths_query=decode_metadata.actual_seq_lengths_q,
884884
actual_seq_lengths_kv=decode_metadata.seq_lens,
@@ -981,6 +981,6 @@ def indexer_select(
981981
block_table=block_table,
982982
layout_query="TND",
983983
layout_key="PA_BSND",
984-
selected_count=2048,
984+
sparse_count=2048,
985985
sparse_mode=3)
986986
return topk_indices

vllm_ascend/torchair/torchair_sfa.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,13 +1070,13 @@ def forward(
10701070
k_nope, k_rope = key_states
10711071
prefill_metadata = attn_metadata.prefill
10721072

1073-
slc_fa_fusion = torch.ops.custom.npu_selected_flash_attention(
1073+
slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention(
10741074
query=q_nope,
10751075
key=k_nope,
10761076
value=k_nope,
1077-
selected_indices=topk_indices,
1077+
sparse_indices=topk_indices,
10781078
scale_value=self.scale,
1079-
selected_block_size=1,
1079+
sparse_block_size=1,
10801080
block_table=prefill_metadata.block_table,
10811081
actual_seq_lengths_query=prefill_metadata.query_lens,
10821082
actual_seq_lengths_kv=prefill_metadata.seq_lens,
@@ -1175,13 +1175,13 @@ def forward(
11751175
k_nope, k_rope = key_states
11761176

11771177
decode_metadata = attn_metadata.decode
1178-
slc_fa_fusion = torch.ops.custom.npu_selected_flash_attention(
1178+
slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention(
11791179
query=q_nope,
11801180
key=k_nope,
11811181
value=k_nope,
1182-
selected_indices=topk_indices,
1182+
sparse_indices=topk_indices,
11831183
scale_value=self.scale,
1184-
selected_block_size=1,
1184+
sparse_block_size=1,
11851185
block_table=attn_metadata.decode.block_table,
11861186
actual_seq_lengths_query=decode_metadata.actual_seq_lengths_q,
11871187
actual_seq_lengths_kv=decode_metadata.seq_lens,
@@ -1292,7 +1292,7 @@ def indexer_select(
12921292
block_table=block_table,
12931293
layout_query="TND",
12941294
layout_key="PA_BSND",
1295-
selected_count=2048,
1295+
sparse_count=2048,
12961296
sparse_mode=3)
12971297
return topk_indices
12981298

vllm_ascend/worker/worker_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
import custom_ops # type: ignore[import-untyped] # noqa
9595
logger.info(
9696
"custom_ops module loaded successfully. Custom operators like "
97-
"torch.ops.custom.npu_selected_flash_attention are now available."
97+
"torch.ops.custom.npu_sparse_flash_attention are now available."
9898
)
9999

100100
super().__init__(vllm_config=vllm_config,

0 commit comments

Comments (0)