Commit cd337a7

fix ci
1 parent 9e2ddee commit cd337a7

2 files changed: +21 -14 lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 19 additions & 12 deletions
@@ -185,7 +185,8 @@ def __init__(self,
         self.device = device
         scheduler_config = vllm_config.scheduler_config
         self.block_size = vllm_config.cache_config.block_size
-        self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size
+        self.max_blocks = (vllm_config.model_config.max_model_len +
+                           self.block_size - 1) // self.block_size
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
         if self.chunked_prefill_enabled:
             self.chunked_prefill_workspace_size = min(
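The reflowed self.max_blocks assignment is a plain ceiling division: the model's maximum sequence length divided by the KV-cache block size, rounded up. A minimal sketch of the arithmetic, with illustrative values in place of the real config fields:

# Ceiling division: number of cache blocks needed to hold max_model_len tokens.
max_model_len = 4096   # stand-in for vllm_config.model_config.max_model_len
block_size = 128       # stand-in for vllm_config.cache_config.block_size
max_blocks = (max_model_len + block_size - 1) // block_size
assert max_blocks == 32   # 4096 fits exactly; 4097 tokens would need 33 blocks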
@@ -216,7 +217,10 @@ def __init__(self,
         self.cos_cache = None
         self.sin_cache = None
         self.prefill_attn_mask = torch.triu(
-            torch.ones(512, 512, device=self.device, dtype=self.model_config.dtype),
+            torch.ones(512,
+                       512,
+                       device=self.device,
+                       dtype=self.model_config.dtype),
             1) # 512: mask only support 512

     def reorder_batch(self, input_batch: "InputBatch",
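The prefill_attn_mask built above is a strictly upper-triangular matrix: entry (i, j) is 1 exactly when j > i, i.e. for the future positions a query token must not attend to. A small sketch of the same construction at size 4 instead of 512, on CPU with float32 for illustration:

import torch

# diagonal=1 keeps only entries strictly above the main diagonal,
# so row i is 1 at every future position j > i.
mask = torch.triu(torch.ones(4, 4, dtype=torch.float32), 1)
# tensor([[0., 1., 1., 1.],
#         [0., 0., 1., 1.],
#         [0., 0., 0., 1.],
#         [0., 0., 0., 0.]])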
@@ -384,13 +388,13 @@ def build(

         num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu -
                                    query_seq_lens_cpu)
-
+
         use_torchair_graph = num_token_pad_size != -1
         if use_torchair_graph and self.runner.attn_state in [
-            AscendAttentionState.DecodeOnly,
-            AscendAttentionState.SpecDecoding
-        ]:
-            decode_threshold = self.runner.decode_token_per_req
+                AscendAttentionState.DecodeOnly,
+                AscendAttentionState.SpecDecoding
+        ]:
+            decode_threshold = self.runner.decode_token_per_req
         else:
             # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
             decode_threshold = 1
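decode_threshold presumably bounds how many query tokens a request may carry and still be treated as a decode: 1 in the plain case, decode_token_per_req when the torchair graph runs speculative decoding. A hedged sketch of that kind of classification; the helper name and the split logic are assumptions for illustration, not code from this commit:

# Hypothetical: count how many leading requests are decodes, assuming decodes
# have already been reordered to the front of the batch (see reorder_batch above).
def count_decodes(query_lens: list[int], decode_threshold: int) -> int:
    num_decodes = 0
    for q in query_lens:
        if q > decode_threshold:
            break
        num_decodes += 1
    return num_decodes

# With speculative decoding each decode request carries 1 target + N draft tokens,
# so the threshold becomes 1 + N instead of 1.
print(count_decodes([2, 2, 7, 30], decode_threshold=2))  # -> 2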
@@ -420,16 +424,16 @@ def build(
         prefill_metadata = None
         if num_prefills > 0:
             reqs_start = num_decodes # prefill_start
-
+
             context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
             max_context_len_cpu = context_lens_cpu.max().item()
             num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
             prefill_query_start_loc = query_start_loc[
                 reqs_start:] - query_start_loc[reqs_start]
-
+
             tokens_start = num_decode_tokens
             chunked_context_metadata = None
-
+
             if self.chunked_prefill_enabled and max_context_len_cpu > 0:
                 # currently we allocate an equal amount of workspace for each
                 # prefill in the batch, we could probably use a more advanced
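The prefill_query_start_loc expression re-bases the cumulative query offsets so that the prefill slice starts at zero. Assuming query_start_loc is the usual running sum of per-request query lengths (an assumption about the surrounding builder, not visible in this diff), a tiny worked example:

import torch

# Offsets for 4 requests with query lengths [1, 1, 5, 8]; two decodes come first.
query_start_loc = torch.tensor([0, 1, 2, 7, 15])
reqs_start = 2  # stand-in for num_decodes

# Subtracting the offset of the first prefill request shifts the prefix sums to start at 0.
prefill_query_start_loc = query_start_loc[reqs_start:] - query_start_loc[reqs_start]
print(prefill_query_start_loc)  # tensor([ 0,  5, 13])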
@@ -504,7 +508,7 @@ def build(
         decode_metadata = None
         use_torchair_graph = num_token_pad_size != -1
         if num_decodes > 0:
-            actual_seq_lengths_q = query_start_loc[1:num_decodes+1].tolist()
+            actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
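The reflowed actual_seq_lengths_q line takes entries 1..num_decodes of query_start_loc, i.e. the running token count after each decode request. Continuing the same illustrative layout as above:

import torch

query_start_loc = torch.tensor([0, 1, 2, 7, 15])  # same hypothetical layout as above
num_decodes = 2

# Skip the leading zero; keep the cumulative query length ending at each decode request.
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
print(actual_seq_lengths_q)  # [1, 2]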
@@ -953,7 +957,10 @@ def _forward_decode(
                                    self.qk_rope_head_dim)
         input_layout = "BNSD"

-        if attn_metadata.attn_state in [AscendAttentionState.SpecDecoding, AscendAttentionState.ChunkedPrefill]:
+        if attn_metadata.attn_state in [
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.ChunkedPrefill
+        ]:
             assert num_tokens % (1 + self.spec_token_num) == 0
             input_layout = "TND"
             # [bs * q_seq_len, num_heads_per_rank, dim]
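The assert in this hunk relies on every request contributing exactly 1 + spec_token_num query tokens when speculative decoding is active, so the flattened token count must divide evenly by that factor before switching to the TND layout. A small sanity-check sketch with illustrative values:

# Each request carries 1 target token plus N draft tokens under speculative decoding.
spec_token_num = 3          # hypothetical number of draft tokens per request
batch_size = 4
num_tokens = batch_size * (1 + spec_token_num)

# The input is flattened to [num_tokens, num_heads_per_rank, dim] ("TND"),
# so num_tokens must split back into whole per-request groups.
assert num_tokens % (1 + spec_token_num) == 0
print(num_tokens // (1 + spec_token_num))  # -> 4 requests recovered from the flat layout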

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 2 deletions
@@ -887,8 +887,8 @@ def _process_reqs(
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
-        if (self.use_aclgraph and total_num_scheduled_tokens
-                <= self.aclgraph_batch_sizes[-1]):
+        if (self.use_aclgraph and
+                total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]):
             # Add padding to the batch size.
             num_input_tokens = self.vllm_config.pad_for_cudagraph(
                 total_num_scheduled_tokens)
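The reflowed condition pads only when the scheduled token count fits under the largest captured ACL-graph batch size; pad_for_cudagraph then rounds the batch up to a captured size. A hedged sketch of that bucketing; the size list and helper below are assumptions for illustration, not the vLLM implementation:

# Hypothetical captured graph sizes, smallest to largest, mirroring aclgraph_batch_sizes.
aclgraph_batch_sizes = [1, 2, 4, 8, 16, 32]

def pad_to_captured_size(num_tokens: int) -> int:
    # Round up to the smallest captured batch size that can hold the batch.
    return next(size for size in aclgraph_batch_sizes if size >= num_tokens)

total_num_scheduled_tokens = 5
if total_num_scheduled_tokens <= aclgraph_batch_sizes[-1]:
    num_input_tokens = pad_to_captured_size(total_num_scheduled_tokens)  # -> 8
else:
    num_input_tokens = total_num_scheduled_tokens  # too large for a captured graph: no padding
print(num_input_tokens)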
