2828from vllm_ascend .attention .mla_v1 import MAX_O_PROJ_PREFETCH_SIZE , MLAPO_MAX_SUPPORTED_TOKENS
2929from vllm_ascend .attention .utils import (
3030 AscendCommonAttentionMetadata ,
31+ AscendLightningIndexerMetadata ,
3132 ascend_chunked_prefill_workspace_size ,
3233 enable_cp ,
34+ get_index_of_skipped_queries_numpy ,
35+ get_sfa_skip_indices ,
3336 maybe_save_kv_layer_to_connector ,
3437 trans_rope_weight ,
3538 transdata ,
3639 wait_for_kv_layer_from_connector ,
37- get_sfa_skip_indices ,
38- get_index_of_skipped_queries_numpy ,
39- AscendLightningIndexerMetadata
4040)
4141from vllm_ascend .device .device_op import DeviceOperator
4242from vllm_ascend .distributed .utils import all_gather_async
5555 dispose_layer ,
5656 enable_dsa_cp ,
5757 enable_dsa_cp_with_layer_shard ,
58- enable_lightning_indexer_skip ,
5958 enable_dsa_cp_with_o_proj_tp ,
59+ enable_lightning_indexer_skip ,
6060 get_weight_prefetch_method ,
6161 maybe_trans_nz ,
6262)
@@ -244,7 +244,7 @@ def build(
244244
245245 cum_query_lens = common_attn_metadata .query_start_loc [1 : num_reqs + 1 ]
246246 seq_lens = common_attn_metadata .seq_lens [:num_reqs ]
247-
247+
248248 my_query_start_loc = common_attn_metadata .query_start_loc [: num_reqs + 1 ]
249249 tokens = my_query_start_loc [1 :] - my_query_start_loc [:- 1 ]
250250
@@ -332,9 +332,8 @@ def build(
332332 top_k_indices_skip_li_query = None
333333 skip = False
334334 if enable_lightning_indexer_skip ():
335-
336- li_reorder_indices , li_cum_query_lens , li_seq_lens , li_skiped_query_mask , num_of_non_skip_tokens = get_sfa_skip_indices (
337- seq_lens - tokens , tokens
335+ li_reorder_indices , li_cum_query_lens , li_seq_lens , li_skiped_query_mask , num_of_non_skip_tokens = (
336+ get_sfa_skip_indices (seq_lens - tokens , tokens )
338337 )
339338 skip = num_of_non_skip_tokens is not None
340339
@@ -343,7 +342,7 @@ def build(
343342 li_cum_query_lens , li_seq_lens , num_reqs , 2048
344343 )
345344 common_attn_metadata .lightning_indexer_metadata = AscendLightningIndexerMetadata (
346- li_reorder_indices = torch .from_numpy (li_reorder_indices )
345+ li_reorder_indices = torch .from_numpy (li_reorder_indices )
347346 .pin_memory ()
348347 .to (dtype = torch .int32 , device = self .device , non_blocking = True ),
349348 li_cum_query_lens = torch .from_numpy (li_cum_query_lens )
@@ -358,7 +357,7 @@ def build(
358357 top_k_indices_of_skipped_queries = torch .from_numpy (top_k_indices_of_skipped_queries_numpy )
359358 .pin_memory ()
360359 .to (dtype = torch .int32 , device = self .device , non_blocking = True ),
361- num_of_non_skip_tokens = num_of_non_skip_tokens
360+ num_of_non_skip_tokens = num_of_non_skip_tokens ,
362361 )
363362 li_reorder_indices = common_attn_metadata .lightning_indexer_metadata .li_reorder_indices
364363 input_positions_pad = torch .zeros_like (input_positions )
@@ -374,7 +373,9 @@ def build(
374373 slot_mapping = slot_mapping_pad
375374 input_positions = input_positions_pad
376375 cos , sin = get_cos_and_sin_mla (input_positions , True )
377- top_k_indices_skip_li_query = common_attn_metadata .lightning_indexer_metadata .top_k_indices_of_skipped_queries
376+ top_k_indices_skip_li_query = (
377+ common_attn_metadata .lightning_indexer_metadata .top_k_indices_of_skipped_queries
378+ )
378379
379380 return self .metadata_cls ( # type: ignore
380381 num_input_tokens = common_attn_metadata .num_input_tokens ,
@@ -389,10 +390,10 @@ def build(
389390 sin = sin [:num_input_tokens ],
390391 cos = cos [:num_input_tokens ],
391392 dsa_cp_context = dsa_cp_context ,
392- num_actual_seqs = num_reqs ,
393- top_k_indices_skip_li_query = top_k_indices_skip_li_query ,
394- non_skip_num_actual_tokens = num_of_non_skip_tokens ,
395- skip = skip
393+ num_actual_seqs = num_reqs ,
394+ top_k_indices_skip_li_query = top_k_indices_skip_li_query ,
395+ non_skip_num_actual_tokens = num_of_non_skip_tokens ,
396+ skip = skip ,
396397 )
397398
398399 def build_for_graph_capture (
@@ -997,8 +998,8 @@ def indexer_select_post_process(
997998 if num_tokens > 0 :
998999 weights , _ = self .weights_proj (x )
9991000
1000- q_li , _ = self .wq_b (q_c ) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
1001- q_li = q_li .view (- 1 , self .n_head , self .head_dim ) # [n_toks,64,128]
1001+ q_li , _ = self .wq_b (q_c ) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
1002+ q_li = q_li .view (- 1 , self .n_head , self .head_dim ) # [n_toks,64,128]
10021003
10031004 # rope
10041005 if HAS_TRITON :
@@ -1014,16 +1015,15 @@ def indexer_select_post_process(
10141015 q_li ,
10151016 [self .qk_rope_head_dim , self .head_dim - self .qk_rope_head_dim ],
10161017 dim = - 1 ,
1017- ) # [b,s,64,64+64]
1018+ ) # [b,s,64,64+64]
10181019
10191020 q_li_pe = torch_npu .npu_rotary_mul (
10201021 q_li_pe .unsqueeze (2 ),
10211022 cos ,
10221023 sin ,
10231024 ).squeeze (2 )
10241025
1025- q_li = torch .cat ([q_li_pe , q_li_nope ], dim = - 1 ) # [b*s,64,128]
1026-
1026+ q_li = torch .cat ([q_li_pe , q_li_nope ], dim = - 1 ) # [b*s,64,128]
10271027
10281028 # =========================
10291029 # step3: run lightning indexer
@@ -1043,9 +1043,9 @@ def indexer_select_post_process(
10431043 query = q_li ,
10441044 key = kv_cache [2 ],
10451045 weights = weights ,
1046- actual_seq_lengths_query = actual_seq_lengths_query [:attn_metadata .num_actual_seqs ],
1047- actual_seq_lengths_key = actual_seq_lengths_key [:attn_metadata .num_actual_seqs ],
1048- block_table = attn_metadata .block_table [:attn_metadata .num_actual_seqs ],
1046+ actual_seq_lengths_query = actual_seq_lengths_query [: attn_metadata .num_actual_seqs ],
1047+ actual_seq_lengths_key = actual_seq_lengths_key [: attn_metadata .num_actual_seqs ],
1048+ block_table = attn_metadata .block_table [: attn_metadata .num_actual_seqs ],
10491049 layout_query = "TND" ,
10501050 layout_key = "PA_BSND" ,
10511051 sparse_count = sparse_count ,
@@ -1289,7 +1289,7 @@ def forward(
12891289
12901290 k_li = self ._get_full_kv (k_li , attn_metadata )
12911291
1292- if kv_cache is not None and (not attn_metadata .skip or attn_metadata .non_skip_num_actual_tokens > 0 ):
1292+ if kv_cache is not None and (not attn_metadata .skip or attn_metadata .non_skip_num_actual_tokens > 0 ):
12931293 if self .is_kv_producer :
12941294 attn_metadata .reshape_cache_event = torch .npu .Event ()
12951295 torch_npu .npu_scatter_nd_update_ (
0 commit comments