@@ -185,7 +185,8 @@ def __init__(self,
         self.device = device
         scheduler_config = vllm_config.scheduler_config
         self.block_size = vllm_config.cache_config.block_size
-        self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size
+        self.max_blocks = (vllm_config.model_config.max_model_len +
+                           self.block_size - 1) // self.block_size
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
         if self.chunked_prefill_enabled:
             self.chunked_prefill_workspace_size = min(
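A minimal standalone sketch of the ceiling division behind self.max_blocks in the hunk above; the concrete max_model_len and block_size values are made up for illustration, not taken from any real config:

    # Illustrative values only -- not from the patch or a real vLLM config.
    max_model_len = 4096
    block_size = 128

    # (a + b - 1) // b is integer ceiling division: the number of fixed-size
    # KV-cache blocks needed to cover max_model_len tokens.
    max_blocks = (max_model_len + block_size - 1) // block_size
    assert max_blocks == 32

    # A length that is not an exact multiple of block_size still rounds up.
    assert (4097 + block_size - 1) // block_size == 33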
@@ -216,7 +217,10 @@ def __init__(self,
         self.cos_cache = None
         self.sin_cache = None
         self.prefill_attn_mask = torch.triu(
-            torch.ones(512, 512, device=self.device, dtype=self.model_config.dtype),
+            torch.ones(512,
+                       512,
+                       device=self.device,
+                       dtype=self.model_config.dtype),
             1)  # 512: mask only support 512

     def reorder_batch(self, input_batch: "InputBatch",
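For reference, a small self-contained sketch of the mask construction above: torch.triu(..., 1) keeps only the strictly upper triangle, i.e. it marks the future positions a causal prefill mask is meant to hide. The 4x4 size here is just to keep the output readable; the patch itself uses 512x512 on the NPU device:

    import torch

    # Same construction as prefill_attn_mask above, at a small size and on CPU.
    # diagonal=1 zeroes the main diagonal and below, so entry (i, j) is 1 only
    # for j > i -- the positions after the current token.
    size = 4
    mask = torch.triu(torch.ones(size, size, dtype=torch.float16), 1)
    print(mask)
    # tensor([[0., 1., 1., 1.],
    #         [0., 0., 1., 1.],
    #         [0., 0., 0., 1.],
    #         [0., 0., 0., 0.]], dtype=torch.float16)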
@@ -384,13 +388,13 @@ def build(

         num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu -
                                    query_seq_lens_cpu)
-
+
         use_torchair_graph = num_token_pad_size != -1
         if use_torchair_graph and self.runner.attn_state in [
-            AscendAttentionState.DecodeOnly,
-            AscendAttentionState.SpecDecoding
-        ]:
-            decode_threshold = self.runner.decode_token_per_req
+                AscendAttentionState.DecodeOnly,
+                AscendAttentionState.SpecDecoding
+        ]:
+            decode_threshold = self.runner.decode_token_per_req
         else:
             # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
             decode_threshold = 1
@@ -420,16 +424,16 @@ def build(
         prefill_metadata = None
         if num_prefills > 0:
             reqs_start = num_decodes  # prefill_start
-
+
             context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
             max_context_len_cpu = context_lens_cpu.max().item()
             num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
             prefill_query_start_loc = query_start_loc[
                 reqs_start:] - query_start_loc[reqs_start]
-
+
             tokens_start = num_decode_tokens
             chunked_context_metadata = None
-
+
             if self.chunked_prefill_enabled and max_context_len_cpu > 0:
                 # currently we allocate an equal amount of workspace for each
                 # prefill in the batch, we could probably use a more advanced
@@ -504,7 +508,7 @@ def build(
         decode_metadata = None
         use_torchair_graph = num_token_pad_size != -1
         if num_decodes > 0:
-            actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
+            actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
@@ -953,7 +957,10 @@ def _forward_decode(
                              self.qk_rope_head_dim)
         input_layout = "BNSD"

-        if attn_metadata.attn_state in [AscendAttentionState.SpecDecoding, AscendAttentionState.ChunkedPrefill]:
+        if attn_metadata.attn_state in [
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.ChunkedPrefill
+        ]:
             assert num_tokens % (1 + self.spec_token_num) == 0
             input_layout = "TND"
             # [bs * q_seq_len, num_heads_per_rank, dim]
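Finally, a tiny sketch of the divisibility check asserted in _forward_decode above, assuming (as the surrounding code suggests) that under speculative decoding each request contributes one regular token plus spec_token_num draft tokens; the concrete numbers below are hypothetical:

    # Hypothetical batch shape, for illustration only.
    spec_token_num = 3                        # draft tokens proposed per request
    num_requests = 5
    tokens_per_request = 1 + spec_token_num   # 1 regular + 3 speculative tokens

    num_tokens = num_requests * tokens_per_request   # 20 flattened tokens
    assert num_tokens % (1 + spec_token_num) == 0    # holds for any num_requests

    # A batch that did not flatten to a whole number of requests would fail:
    assert (num_tokens + 1) % (1 + spec_token_num) != 0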