Skip to content

Commit 8cac148

Browse files
committed
format
1 parent 178436b commit 8cac148

File tree

4 files changed

+31
-29
lines changed

4 files changed

+31
-29
lines changed

vllm_ascend/attention/sfa_v1.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@
2828
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
2929
from vllm_ascend.attention.utils import (
3030
AscendCommonAttentionMetadata,
31+
AscendLightningIndexerMetadata,
3132
ascend_chunked_prefill_workspace_size,
3233
enable_cp,
34+
get_index_of_skipped_queries_numpy,
35+
get_sfa_skip_indices,
3336
maybe_save_kv_layer_to_connector,
3437
trans_rope_weight,
3538
transdata,
3639
wait_for_kv_layer_from_connector,
37-
get_sfa_skip_indices,
38-
get_index_of_skipped_queries_numpy,
39-
AscendLightningIndexerMetadata
4040
)
4141
from vllm_ascend.device.device_op import DeviceOperator
4242
from vllm_ascend.distributed.utils import all_gather_async
@@ -55,8 +55,8 @@
5555
dispose_layer,
5656
enable_dsa_cp,
5757
enable_dsa_cp_with_layer_shard,
58-
enable_lightning_indexer_skip,
5958
enable_dsa_cp_with_o_proj_tp,
59+
enable_lightning_indexer_skip,
6060
get_weight_prefetch_method,
6161
maybe_trans_nz,
6262
)
@@ -244,7 +244,7 @@ def build(
244244

245245
cum_query_lens = common_attn_metadata.query_start_loc[1 : num_reqs + 1]
246246
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
247-
247+
248248
my_query_start_loc = common_attn_metadata.query_start_loc[: num_reqs + 1]
249249
tokens = my_query_start_loc[1:] - my_query_start_loc[:-1]
250250

@@ -332,9 +332,8 @@ def build(
332332
top_k_indices_skip_li_query = None
333333
skip = False
334334
if enable_lightning_indexer_skip():
335-
336-
li_reorder_indices, li_cum_query_lens, li_seq_lens, li_skiped_query_mask, num_of_non_skip_tokens = get_sfa_skip_indices(
337-
seq_lens-tokens, tokens
335+
li_reorder_indices, li_cum_query_lens, li_seq_lens, li_skiped_query_mask, num_of_non_skip_tokens = (
336+
get_sfa_skip_indices(seq_lens - tokens, tokens)
338337
)
339338
skip = num_of_non_skip_tokens is not None
340339

@@ -343,7 +342,7 @@ def build(
343342
li_cum_query_lens, li_seq_lens, num_reqs, 2048
344343
)
345344
common_attn_metadata.lightning_indexer_metadata = AscendLightningIndexerMetadata(
346-
li_reorder_indices=torch.from_numpy(li_reorder_indices)
345+
li_reorder_indices=torch.from_numpy(li_reorder_indices)
347346
.pin_memory()
348347
.to(dtype=torch.int32, device=self.device, non_blocking=True),
349348
li_cum_query_lens=torch.from_numpy(li_cum_query_lens)
@@ -358,7 +357,7 @@ def build(
358357
top_k_indices_of_skipped_queries=torch.from_numpy(top_k_indices_of_skipped_queries_numpy)
359358
.pin_memory()
360359
.to(dtype=torch.int32, device=self.device, non_blocking=True),
361-
num_of_non_skip_tokens = num_of_non_skip_tokens
360+
num_of_non_skip_tokens=num_of_non_skip_tokens,
362361
)
363362
li_reorder_indices = common_attn_metadata.lightning_indexer_metadata.li_reorder_indices
364363
input_positions_pad = torch.zeros_like(input_positions)
@@ -374,7 +373,9 @@ def build(
374373
slot_mapping = slot_mapping_pad
375374
input_positions = input_positions_pad
376375
cos, sin = get_cos_and_sin_mla(input_positions, True)
377-
top_k_indices_skip_li_query = common_attn_metadata.lightning_indexer_metadata.top_k_indices_of_skipped_queries
376+
top_k_indices_skip_li_query = (
377+
common_attn_metadata.lightning_indexer_metadata.top_k_indices_of_skipped_queries
378+
)
378379

379380
return self.metadata_cls( # type: ignore
380381
num_input_tokens=common_attn_metadata.num_input_tokens,
@@ -389,10 +390,10 @@ def build(
389390
sin=sin[:num_input_tokens],
390391
cos=cos[:num_input_tokens],
391392
dsa_cp_context=dsa_cp_context,
392-
num_actual_seqs = num_reqs,
393-
top_k_indices_skip_li_query = top_k_indices_skip_li_query,
394-
non_skip_num_actual_tokens = num_of_non_skip_tokens,
395-
skip = skip
393+
num_actual_seqs=num_reqs,
394+
top_k_indices_skip_li_query=top_k_indices_skip_li_query,
395+
non_skip_num_actual_tokens=num_of_non_skip_tokens,
396+
skip=skip,
396397
)
397398

398399
def build_for_graph_capture(
@@ -997,8 +998,8 @@ def indexer_select_post_process(
997998
if num_tokens > 0:
998999
weights, _ = self.weights_proj(x)
9991000

1000-
q_li, _ = self.wq_b(q_c) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
1001-
q_li = q_li.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
1001+
q_li, _ = self.wq_b(q_c) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
1002+
q_li = q_li.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
10021003

10031004
# rope
10041005
if HAS_TRITON:
@@ -1014,16 +1015,15 @@ def indexer_select_post_process(
10141015
q_li,
10151016
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
10161017
dim=-1,
1017-
) # [b,s,64,64+64]
1018+
) # [b,s,64,64+64]
10181019

10191020
q_li_pe = torch_npu.npu_rotary_mul(
10201021
q_li_pe.unsqueeze(2),
10211022
cos,
10221023
sin,
10231024
).squeeze(2)
10241025

1025-
q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128]
1026-
1026+
q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128]
10271027

10281028
# =========================
10291029
# step3: run lightning indexer
@@ -1043,9 +1043,9 @@ def indexer_select_post_process(
10431043
query=q_li,
10441044
key=kv_cache[2],
10451045
weights=weights,
1046-
actual_seq_lengths_query=actual_seq_lengths_query[:attn_metadata.num_actual_seqs],
1047-
actual_seq_lengths_key=actual_seq_lengths_key[:attn_metadata.num_actual_seqs],
1048-
block_table=attn_metadata.block_table[:attn_metadata.num_actual_seqs],
1046+
actual_seq_lengths_query=actual_seq_lengths_query[: attn_metadata.num_actual_seqs],
1047+
actual_seq_lengths_key=actual_seq_lengths_key[: attn_metadata.num_actual_seqs],
1048+
block_table=attn_metadata.block_table[: attn_metadata.num_actual_seqs],
10491049
layout_query="TND",
10501050
layout_key="PA_BSND",
10511051
sparse_count=sparse_count,
@@ -1289,7 +1289,7 @@ def forward(
12891289

12901290
k_li = self._get_full_kv(k_li, attn_metadata)
12911291

1292-
if kv_cache is not None and (not attn_metadata.skip or attn_metadata.non_skip_num_actual_tokens > 0):
1292+
if kv_cache is not None and (not attn_metadata.skip or attn_metadata.non_skip_num_actual_tokens > 0):
12931293
if self.is_kv_producer:
12941294
attn_metadata.reshape_cache_event = torch.npu.Event()
12951295
torch_npu.npu_scatter_nd_update_(

vllm_ascend/attention/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from dataclasses import dataclass, field
22
from functools import lru_cache
33
from typing import Any
4-
import numpy as np
54

5+
import numpy as np
66
import torch
77
import torch.nn.functional as F
88
from vllm.config import VllmConfig, get_current_vllm_config
@@ -121,6 +121,7 @@ class AscendPrefillContextParallelMetadata:
121121
# the number of tokens padded in linear-attn per rank
122122
pcp_padded_tokens_fla: int = 0
123123

124+
124125
@dataclass
125126
class AscendLightningIndexerMetadata:
126127
li_reorder_indices: torch.Tensor = None
@@ -130,6 +131,7 @@ class AscendLightningIndexerMetadata:
130131
top_k_indices_of_skipped_queries: torch.Tensor = None
131132
num_of_non_skip_tokens: int = 0
132133

134+
133135
@dataclass
134136
class AscendCommonAttentionMetadata(CommonAttentionMetadata):
135137
"""
@@ -342,11 +344,13 @@ def enabling_mlapo(vllm_config: VllmConfig) -> bool:
342344
)
343345
return bool(envs.VLLM_ASCEND_ENABLE_MLAPO and is_decode_instance)
344346

347+
345348
def to_numpy(x):
346349
if isinstance(x, torch.Tensor):
347350
return x.cpu().numpy()
348351
return x
349352

353+
350354
def get_sfa_skip_indices(num_comptuted_tokens, query_lens):
351355
num_comptuted_tokens = to_numpy(num_comptuted_tokens)
352356
query_lens = to_numpy(query_lens)

vllm_ascend/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1255,4 +1255,4 @@ def enable_lightning_indexer_skip() -> bool:
12551255
has_indexer_topk = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
12561256
vllm_config.model_config.hf_text_config, "index_topk"
12571257
)
1258-
return bool(has_indexer_topk and vllm_config.additional_config.get("enable_lightning_indexer_skip", False))
1258+
return bool(has_indexer_topk and vllm_config.additional_config.get("enable_lightning_indexer_skip", False))

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,6 @@
9494
from vllm_ascend.attention.utils import (
9595
AscendCommonAttentionMetadata,
9696
AscendLightningIndexerMetadata,
97-
get_index_of_skipped_queries_numpy,
98-
get_sfa_skip_indices,
9997
hidden_states_reorder,
10098
maybe_pad_and_reorder_inputs,
10199
using_paged_attention,
@@ -127,6 +125,7 @@
127125
from vllm_ascend.utils import (
128126
calc_split_factor,
129127
check_gdn_layer,
128+
enable_lightning_indexer_skip,
130129
enable_sp,
131130
enable_sp_by_pass,
132131
global_stream,
@@ -135,7 +134,6 @@
135134
lmhead_tp_enable,
136135
set_weight_prefetch_method,
137136
vllm_version_is,
138-
enable_lightning_indexer_skip
139137
)
140138
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
141139
from vllm_ascend.worker.pcp_utils import PCPManager

0 commit comments

Comments
 (0)