import numpy as np
import torch
-import torch_npu
import torch.nn as nn
+import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                              AttentionMetadata,
                                              MLAAttentionImpl)
from vllm.attention.backends.utils import PAD_SLOT_ID
-from vllm.config import get_current_vllm_config, VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         split_decodes_and_prefills)
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import npu_prefetch
from vllm_ascend.worker.npu_input_batch import InputBatch
-from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,split_decodes_and_prefills)
-

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

@@ -185,7 +185,8 @@ def __init__(self,
        self.device = device
        scheduler_config = vllm_config.scheduler_config
        self.block_size = vllm_config.cache_config.block_size
-        self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size
+        self.max_blocks = (vllm_config.model_config.max_model_len +
+                           self.block_size - 1) // self.block_size
        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
        if self.chunked_prefill_enabled:
            self.chunked_prefill_workspace_size = min(
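
# Illustrative aside, not part of the patch: the reflowed expression above is
# ordinary ceiling division, so max_blocks is the number of KV-cache blocks
# needed to cover max_model_len tokens. The numbers below are made up.
def ceil_div(a: int, b: int) -> int:
    # (a + b - 1) // b rounds up without floating point
    return (a + b - 1) // b

assert ceil_div(4096, 128) == 32   # exact multiple
assert ceil_div(4097, 128) == 33   # one extra token needs one extra block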
@@ -278,13 +279,13 @@ def reorder_batch(self, input_batch: "InputBatch",
    def _get_graph_runner_block_tables(
            self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
        num_blocks = block_tables.size(1)
-        if num_blocks <= self.max_blocks:
-            return block_tables[:num_seqs, :num_blocks]
-        else:
-            return block_tables[:num_seqs, :self.max_blocks]
+        num_blocks = min(num_blocks, self.max_blocks)
+        return block_tables[:num_seqs, :num_blocks]

    def build_torchair_graph_dummy(
-            self, common_attn_metadata: AscendCommonAttentionMetadata,) -> AscendMLAMetadata:
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+    ) -> AscendMLAMetadata:
        device = self.device
        num_reqs = common_attn_metadata.num_reqs
        block_table = torch.zeros((num_reqs, self.max_blocks),

@@ -332,7 +333,8 @@ def build_torchair_graph_dummy(
            seq_lens_list=seq_lens_list,
            max_seq_lens=1,
            attn_mask=common_attn_metadata.spec_attn_mask,
-            actual_seq_lengths_q=common_attn_metadata.actual_seq_lengths_q[:num_reqs],
+            actual_seq_lengths_q=common_attn_metadata.
+            actual_seq_lengths_q[:num_reqs],
            sin=sin,
            cos=cos,
        )

@@ -362,26 +364,42 @@ def build(
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        query_start_loc = common_attn_metadata.query_start_loc
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
+        if self.torchair_graph_enabled and common_attn_metadata.attn_state in [
+                AscendAttentionState.DecodeOnly,
+                AscendAttentionState.SpecDecoding
+        ]:
+            decode_threshold = common_attn_metadata.decode_token_per_req
+        else:
+            # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
+            decode_threshold = 1
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
-            split_decodes_and_prefills(common_attn_metadata)
+            split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold)
        assert num_decodes + num_prefills == num_reqs
+        assert num_decode_tokens + num_prefill_tokens == num_actual_tokens

        # Note(simon): be careful about the CPU <> GPU memory movement in this
        # function. We should avoid GPU -> CPU sync as much as possible because
        # it blocks on all previous kernels.
        device = self.device

        block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to(
-            device, non_blocking=True)
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
+                                                             num_actual_tokens].to(
+                                                                 device,
+                                                                 non_blocking=
+                                                                 True)
        # input_positions = common_attn_metadata.positions_cpu[:num_actual_tokens].to(
        # device, non_blocking=True).long()
-
-        input_positions = common_attn_metadata.positions[:num_actual_tokens].long()
+
+        input_positions = common_attn_metadata.positions[:
+                                                          num_actual_tokens].long(
+                                                          )

        if self.cos_cache is None:
-            self.cos_cache = model.model.layers[0].self_attn.rotary_emb.cos_cached
-            self.sin_cache = model.model.layers[0].self_attn.rotary_emb.sin_cached
+            self.cos_cache = model.model.layers[
+                0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = model.model.layers[
+                0].self_attn.rotary_emb.sin_cached
        if self.cos_cache.dtype != self.model_config.dtype:  # type: ignore
            self.cos_cache = self.cos_cache.to(  # type: ignore
                self.model_config.dtype)  # type: ignore
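
# Illustrative aside, not part of the patch: a rough sketch of what
# split_decodes_and_prefills is assumed to do with the new decode_threshold
# argument. Requests whose query length stays within the threshold are counted
# as decodes; the first longer request starts the prefill region. The helper's
# real implementation may differ.
def split_sketch(query_lens, decode_threshold=1):
    num_decodes = 0
    for qlen in query_lens:
        if qlen > decode_threshold:
            break
        num_decodes += 1
    num_prefills = len(query_lens) - num_decodes
    num_decode_tokens = sum(query_lens[:num_decodes])
    num_prefill_tokens = sum(query_lens[num_decodes:])
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# With decode_token_per_req = 2 (e.g. one speculative draft token per request),
# three 2-token decodes followed by a 35-token prefill split as expected:
assert split_sketch([2, 2, 2, 35], decode_threshold=2) == (3, 1, 6, 35)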
@@ -392,7 +410,7 @@ def build(
        query_lens = query_seq_lens_cpu[:num_reqs]
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
        num_computed_tokens_cpu = (seq_lens - query_lens)
-
+
        prefill_metadata = None
        chunked_context_metadata = None
        if num_prefills > 0:

@@ -477,8 +495,8 @@ def build(
                pad_value = 0
                num_token_pad_size = graph_pad_size - num_decode_tokens
                num_reqs_pad_size = (
-                    graph_pad_size // common_attn_metadata.decode_token_per_req -
-                    num_reqs)
+                    graph_pad_size //
+                    common_attn_metadata.decode_token_per_req - num_reqs)
                padded_seq_lens = seq_lens.tolist(
                ) + [pad_value] * num_reqs_pad_size
            else:
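
# Illustrative arithmetic with made-up numbers rather than values from this
# patch: if the captured torchair graph expects graph_pad_size = 16 tokens,
# each decode request contributes decode_token_per_req = 2 tokens, and the
# batch currently holds num_reqs = 6 requests, the padding above works out as:
graph_pad_size = 16
decode_token_per_req = 2
num_reqs = 6
num_decode_tokens = num_reqs * decode_token_per_req                     # 12
num_token_pad_size = graph_pad_size - num_decode_tokens                 # 4 dummy tokens
num_reqs_pad_size = graph_pad_size // decode_token_per_req - num_reqs   # 2 dummy requests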
@@ -506,8 +524,8 @@ def build(
                    input_positions = torch.cat(
                        [input_positions, position_padding])
                actual_seq_lengths_q = query_start_loc[1:].tolist(
-                ) + common_attn_metadata.actual_seq_lengths_q[num_reqs:num_reqs +
-                                                              num_reqs_pad_size]
+                ) + common_attn_metadata.actual_seq_lengths_q[
+                    num_reqs:num_reqs + num_reqs_pad_size]
            else:
                seq_lens_list = seq_lens.tolist()
                # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
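
# Illustrative aside, not part of the patch, using made-up numbers: the hunk
# above extends the cumulative query offsets with precomputed entries for the
# padded dummy requests, so the list ends exactly at the token count the
# captured graph was built for (the invariant the comment above refers to).
query_start_loc = [0, 2, 4, 6]        # 3 real requests, 2 tokens each
precomputed = [2, 4, 6, 8]            # hypothetical actual_seq_lengths_q for the padded batch
num_reqs, num_reqs_pad_size = 3, 1    # one dummy request of padding
actual_seq_lengths_q = query_start_loc[1:] + precomputed[num_reqs:num_reqs + num_reqs_pad_size]
assert actual_seq_lengths_q == [2, 4, 6, 8]
assert actual_seq_lengths_q[-1] == 8  # last element equals the padded token count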