Commit 511de2f

refact attn metadata build
Signed-off-by: weiguihua2 <[email protected]>
1 parent 7a6fd01 commit 511de2f

File tree

8 files changed: +133 −88 lines

vllm_ascend/attention/attention_v1.py

Lines changed: 23 additions & 12 deletions

@@ -20,21 +20,21 @@
 from typing import List, Optional, Tuple, Type
 
 import torch
-import torch_npu
 import torch.nn as nn
-from vllm.config import VllmConfig
+import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.config import VllmConfig
 from vllm.forward_context import ForwardContext, get_forward_context
-from vllm.utils import direct_register_custom_op, cdiv
+from vllm.utils import cdiv, direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
 
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)
 from vllm_ascend.worker.npu_input_batch import InputBatch
-from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 
 
 class AscendAttentionBackend(AttentionBackend):

@@ -160,9 +160,11 @@ class AscendMetadata:
 
 class AscendAttentionMetadataBuilder:
 
-    def __init__(self,
+    def __init__(
+        self,
         vllm_config: VllmConfig,
-        device: torch.device,):
+        device: torch.device,
+    ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.device = device

@@ -173,24 +175,33 @@ def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
-    def build(self,
+    def build(
+        self,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,):
+        model: nn.Module,
+    ):
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
+                                                                       num_reqs
+                                                                       + 1]
 
         block_table = common_attn_metadata.block_table_tensor
         block_table[:num_reqs, :self.max_num_blocks_per_req] = (
             block_table[:num_reqs])
 
         query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to(
-            self.device, non_blocking=True)
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
+                                                             num_actual_tokens].to(
+                                                                 self.device,
+                                                                 non_blocking=
+                                                                 True)
         attn_mask = common_attn_metadata.attn_mask
         attn_state = common_attn_metadata.attn_state
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
+                                                                       num_reqs
+                                                                       + 1]
         query_start_loc = query_start_loc_cpu.to(self.device,
                                                  non_blocking=True)
 
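For reference, the core slicing in the refactored build() can be reproduced standalone: query_start_loc_cpu is a cumulative-offset tensor of length num_reqs + 1, so adjacent differences give the per-request query lengths, and the padded slot_mapping_cpu is cut down to the actually scheduled tokens before the non-blocking copy to the device. A minimal CPU-only sketch with invented sizes (the final .to(self.device, non_blocking=True) step is omitted here):

import torch

# Toy batch: 3 requests scheduled with 2, 1 and 4 tokens respectively.
num_reqs = 3
query_start_loc_cpu = torch.tensor([0, 2, 3, 7], dtype=torch.int32)  # length == num_reqs + 1
num_actual_tokens = int(query_start_loc_cpu[num_reqs])               # 7 tokens in total

# Per-request query lengths are the adjacent differences of the offsets.
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]      # tensor([2, 1, 4])

# The slot mapping is allocated for the padded batch, so it is truncated to
# the scheduled token count before being moved to the NPU.
slot_mapping_cpu = torch.arange(16, dtype=torch.int32)               # padded placeholder
slot_mapping = slot_mapping_cpu[:num_actual_tokens]

print(query_lens.tolist(), num_actual_tokens, slot_mapping.tolist())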

vllm_ascend/attention/mla_v1.py

Lines changed: 41 additions & 23 deletions

@@ -3,13 +3,13 @@
 
 import numpy as np
 import torch
-import torch_npu
 import torch.nn as nn
+import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
-from vllm.config import get_current_vllm_config, VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)

@@ -18,15 +18,15 @@
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         split_decodes_and_prefills)
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
 from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
-from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,split_decodes_and_prefills)
-
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput

@@ -185,7 +185,8 @@ def __init__(self,
         self.device = device
         scheduler_config = vllm_config.scheduler_config
         self.block_size = vllm_config.cache_config.block_size
-        self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size
+        self.max_blocks = (vllm_config.model_config.max_model_len +
+                           self.block_size - 1) // self.block_size
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
         if self.chunked_prefill_enabled:
             self.chunked_prefill_workspace_size = min(

@@ -278,13 +279,13 @@ def reorder_batch(self, input_batch: "InputBatch",
     def _get_graph_runner_block_tables(
             self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
         num_blocks = block_tables.size(1)
-        if num_blocks <= self.max_blocks:
-            return block_tables[:num_seqs, :num_blocks]
-        else:
-            return block_tables[:num_seqs, :self.max_blocks]
+        num_blocks = min(num_blocks, self.max_blocks)
+        return block_tables[:num_seqs, :num_blocks]
 
     def build_torchair_graph_dummy(
-            self, common_attn_metadata: AscendCommonAttentionMetadata,) -> AscendMLAMetadata:
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+    ) -> AscendMLAMetadata:
         device = self.device
         num_reqs = common_attn_metadata.num_reqs
         block_table = torch.zeros((num_reqs, self.max_blocks),

@@ -332,7 +333,8 @@ def build_torchair_graph_dummy(
             seq_lens_list=seq_lens_list,
             max_seq_lens=1,
             attn_mask=common_attn_metadata.spec_attn_mask,
-            actual_seq_lengths_q=common_attn_metadata.actual_seq_lengths_q[:num_reqs],
+            actual_seq_lengths_q=common_attn_metadata.
+            actual_seq_lengths_q[:num_reqs],
             sin=sin,
             cos=cos,
         )
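Both small refactors above can be checked in isolation: max_blocks is a ceiling division of max_model_len by the KV-cache block_size, and _get_graph_runner_block_tables now clamps the column count with min() instead of branching. A rough standalone equivalent with toy numbers (written as a free function purely for illustration, not the builder method itself):

import torch

# Ceiling division: how many KV-cache blocks a max_model_len request can occupy.
max_model_len, block_size = 1000, 128
max_blocks = (max_model_len + block_size - 1) // block_size   # == 8

def clamp_block_tables(num_seqs: int, block_tables: torch.Tensor,
                       max_blocks: int) -> torch.Tensor:
    # Same effect as the removed if/else: never return more than max_blocks columns.
    num_blocks = block_tables.size(1)
    num_blocks = min(num_blocks, max_blocks)
    return block_tables[:num_seqs, :num_blocks]

tables = torch.zeros((4, 12), dtype=torch.int32)
print(clamp_block_tables(2, tables, max_blocks).shape)   # torch.Size([2, 8])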
@@ -362,26 +364,42 @@ def build(
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         query_start_loc = common_attn_metadata.query_start_loc
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
+        if self.torchair_graph_enabled and common_attn_metadata.attn_state in [
+                AscendAttentionState.DecodeOnly,
+                AscendAttentionState.SpecDecoding
+        ]:
+            decode_threshold = common_attn_metadata.decode_token_per_req
+        else:
+            # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
+            decode_threshold = 1
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
-            split_decodes_and_prefills(common_attn_metadata)
+            split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold)
         assert num_decodes + num_prefills == num_reqs
+        assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
 
         # Note(simon): be careful about the CPU <> GPU memory movement in this
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.device
 
         block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to(
-            device, non_blocking=True)
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
+                                                             num_actual_tokens].to(
+                                                                 device,
+                                                                 non_blocking=
+                                                                 True)
         # input_positions = common_attn_metadata.positions_cpu[:num_actual_tokens].to(
         #     device, non_blocking=True).long()
-
-        input_positions = common_attn_metadata.positions[:num_actual_tokens].long()
+
+        input_positions = common_attn_metadata.positions[:
+                                                         num_actual_tokens].long(
+                                                         )
 
         if self.cos_cache is None:
-            self.cos_cache = model.model.layers[0].self_attn.rotary_emb.cos_cached
-            self.sin_cache = model.model.layers[0].self_attn.rotary_emb.sin_cached
+            self.cos_cache = model.model.layers[
+                0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = model.model.layers[
+                0].self_attn.rotary_emb.sin_cached
         if self.cos_cache.dtype != self.model_config.dtype:  # type: ignore
             self.cos_cache = self.cos_cache.to(  # type: ignore
                 self.model_config.dtype)  # type: ignore

@@ -392,7 +410,7 @@ def build(
         query_lens = query_seq_lens_cpu[:num_reqs]
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         num_computed_tokens_cpu = (seq_lens - query_lens)
-
+
         prefill_metadata = None
         chunked_context_metadata = None
         if num_prefills > 0:

@@ -477,8 +495,8 @@ def build(
                 pad_value = 0
                 num_token_pad_size = graph_pad_size - num_decode_tokens
                 num_reqs_pad_size = (
-                    graph_pad_size // common_attn_metadata.decode_token_per_req -
-                    num_reqs)
+                    graph_pad_size //
+                    common_attn_metadata.decode_token_per_req - num_reqs)
                 padded_seq_lens = seq_lens.tolist(
                 ) + [pad_value] * num_reqs_pad_size
             else:

@@ -506,8 +524,8 @@ def build(
                     input_positions = torch.cat(
                         [input_positions, position_padding])
                     actual_seq_lengths_q = query_start_loc[1:].tolist(
-                    ) + common_attn_metadata.actual_seq_lengths_q[num_reqs:num_reqs +
-                                                                  num_reqs_pad_size]
+                    ) + common_attn_metadata.actual_seq_lengths_q[
+                        num_reqs:num_reqs + num_reqs_pad_size]
                 else:
                     seq_lens_list = seq_lens.tolist()
             # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
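The padding arithmetic used in the torchair-graph branch above is easier to follow with concrete numbers. A small sketch with invented values (the names mirror the diff; nothing here calls into vllm-ascend):

# Suppose the captured graph expects 8 decode tokens, the batch holds 3 requests
# with 2 decode tokens per request (e.g. MTP speculative decoding), and only
# 6 decode tokens were actually scheduled.
graph_pad_size = 8
decode_token_per_req = 2
num_reqs = 3
num_decode_tokens = 6

num_token_pad_size = graph_pad_size - num_decode_tokens                # 2 extra token slots
num_reqs_pad_size = graph_pad_size // decode_token_per_req - num_reqs  # 1 padded request

pad_value = 0
seq_lens = [5, 7, 9]                                          # per-request sequence lengths
padded_seq_lens = seq_lens + [pad_value] * num_reqs_pad_size  # [5, 7, 9, 0]
print(num_token_pad_size, num_reqs_pad_size, padded_seq_lens)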

vllm_ascend/attention/utils.py

Lines changed: 11 additions & 6 deletions

@@ -1,7 +1,5 @@
 from dataclasses import dataclass
-from enum import Enum
-
-from vllm.config import SpeculativeConfig
+from typing import Any
 
 import torch
 

@@ -29,26 +27,32 @@ class AscendCommonAttentionMetadata:
     """Total number of tokens in batch"""
 
     max_query_len: int
+    """Max token number of request in batch"""
 
     decode_token_per_req: int
+    """decode token number per request"""
 
     block_table_tensor: torch.Tensor
+
     slot_mapping_cpu: torch.Tensor
 
     actual_seq_lengths_q: list[int] = None
 
     positions: torch.Tensor = None
 
     attn_mask: torch.Tensor = None
+
     spec_attn_mask: torch.Tensor = None
-    attn_state: Enum = None
-
+
+    attn_state: Any = None
+
     enable_dbo_across_dp: bool = False
 
     is_only_prefill: bool = False
 
     graph_pad_size: int = -1
 
+
 @dataclass
 class TorchairCommonAttentionMetadata:
     """

@@ -60,6 +64,7 @@ class TorchairCommonAttentionMetadata:
 
     num_reqs: int
     """Number of requests"""
+
     num_actual_tokens: int
     """Total number of tokens in batch"""
 

@@ -68,6 +73,7 @@ class TorchairCommonAttentionMetadata:
     actual_seq_lengths_q: list[int] = None
 
     attn_mask: torch.Tensor = None
+
     spec_attn_mask: torch.Tensor = None
 
     graph_pad_size: int = -1

@@ -113,4 +119,3 @@ def split_decodes_and_prefills(
     num_decode_tokens = query_start_loc[first_prefill].item()
     num_prefill_tokens = num_tokens - num_decode_tokens
     return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
-
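
Only the tail of split_decodes_and_prefills appears in this hunk. Its contract, as exercised by the MLA builder above, can be sketched standalone under two assumptions: the batch is ordered with decode requests first, and a request counts as a decode when its query length is at most decode_threshold (1 by default, decode_token_per_req on the torchair spec-decode path). A rough reimplementation for illustration only, not the library function:

import torch

def split_decodes_and_prefills_sketch(query_start_loc: torch.Tensor,
                                      decode_threshold: int = 1):
    # query_start_loc: cumulative token offsets, shape [num_reqs + 1], decodes first.
    num_reqs = query_start_loc.numel() - 1
    num_tokens = int(query_start_loc[-1])
    query_lens = query_start_loc[1:] - query_start_loc[:-1]

    is_prefill = query_lens > decode_threshold
    if not bool(is_prefill.any()):
        # Decode-only batch.
        return num_reqs, 0, num_tokens, 0

    first_prefill = int(is_prefill.nonzero()[0].item())
    num_decodes = first_prefill
    num_prefills = num_reqs - num_decodes
    num_decode_tokens = int(query_start_loc[first_prefill])
    num_prefill_tokens = num_tokens - num_decode_tokens
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# 2 decode requests (1 token each) followed by 2 prefills (5 and 3 tokens).
qsl = torch.tensor([0, 1, 2, 7, 10])
print(split_decodes_and_prefills_sketch(qsl))   # (2, 2, 2, 8)
# The two invariants asserted in mla_v1.py hold: 2 + 2 == 4 requests, 2 + 8 == 10 tokens.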
116-
