Commit c416bbf

refact attn metadata build
Signed-off-by: weiguihua2 <[email protected]>
1 parent c4c07a0 commit c416bbf

File tree: 5 files changed (+33, −25 lines)

vllm_ascend/attention/attention_v1.py

Lines changed: 4 additions & 2 deletions
@@ -27,7 +27,7 @@
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.forward_context import ForwardContext, get_forward_context
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, cdiv
 from vllm.v1.core.sched.output import SchedulerOutput

 from vllm_ascend.ops.attention import vanilla_chunked_prefill
@@ -165,6 +165,8 @@ def __init__(self,
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.device = device
+        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
+                                           vllm_config.cache_config.block_size)

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -178,7 +180,7 @@ def build(self,
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]

         block_table = common_attn_metadata.block_table_tensor
-        block_table[:num_reqs, :common_attn_metadata.max_num_blocks_per_req] = (
+        block_table[:num_reqs, :self.max_num_blocks_per_req] = (
             block_table[:num_reqs])

         query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

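For context, here is a minimal self-contained sketch of the ceiling-division arithmetic the builder now performs itself instead of reading max_num_blocks_per_req from the common attention metadata. cdiv is re-implemented inline so the snippet runs without vllm installed, and the max_model_len / block_size values are hypothetical:

def cdiv(a: int, b: int) -> int:
    # Ceiling division without floating point, matching the intent of vllm.utils.cdiv.
    return -(a // -b)

max_model_len = 4096   # hypothetical model_config.max_model_len
block_size = 128       # hypothetical cache_config.block_size

max_num_blocks_per_req = cdiv(max_model_len, block_size)
print(max_num_blocks_per_req)  # 32: at most 32 KV-cache blocks per request
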
vllm_ascend/attention/attention_v1_torchair.py

Lines changed: 7 additions & 3 deletions
@@ -23,6 +23,7 @@
 import torch_npu
 import torch.nn as nn
 from vllm.config import VllmConfig
+from vllm.utils import cdiv
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
@@ -32,7 +33,7 @@
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d)
 from vllm_ascend.worker.npu_input_batch import InputBatch
-from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, get_decode_token_per_req


 class AscendAttentionTorchairBackend(AttentionBackend):
@@ -154,6 +155,9 @@ def __init__(self,
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.device = device
+        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
+                                           vllm_config.cache_config.block_size)
+        self.decode_token_per_req = get_decode_token_per_req(vllm_config.speculative_config)

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -214,7 +218,7 @@ def build(self,
         num_actual_tokens = common_attn_metadata.num_actual_tokens

         block_table = common_attn_metadata.block_table_tensor
-        block_table[:num_reqs, :common_attn_metadata.max_num_blocks_per_req] = (
+        block_table[:num_reqs, :self.max_num_blocks_per_req] = (
             block_table[:num_reqs])

         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
@@ -253,7 +257,7 @@ def build(self,
             pad_value = 0
             num_token_pad_size = graph_pad_size - num_actual_tokens
             num_reqs_pad_size = (
-                graph_pad_size // common_attn_metadata.decode_token_per_req -
+                graph_pad_size // self.decode_token_per_req -
                 num_reqs)
             pad_value = 1
             padded_seq_lens = seq_lens.tolist() + [pad_value

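The torchair builder also takes over decode_token_per_req, which drives the graph-padding arithmetic in build(). A rough, standalone sketch of that arithmetic under made-up numbers (not the real builder code):

decode_token_per_req = 2    # e.g. 1 target token + 1 speculative token per request
num_reqs = 3                # decode requests in the batch
num_actual_tokens = num_reqs * decode_token_per_req    # 6 tokens scheduled
graph_pad_size = 16         # hypothetical captured-graph token budget

num_token_pad_size = graph_pad_size - num_actual_tokens                 # 10 padding tokens
num_reqs_pad_size = graph_pad_size // decode_token_per_req - num_reqs   # 5 padding requests

# The padded batch matches the captured graph shape: graph_pad_size tokens
# spread over num_reqs + num_reqs_pad_size requests, decode_token_per_req each.
print(num_token_pad_size, num_reqs_pad_size)  # 10 5
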
vllm_ascend/attention/mla_v1.py

Lines changed: 3 additions & 2 deletions
@@ -25,7 +25,7 @@
 from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
-from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,split_decodes_and_prefills)
+from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,split_decodes_and_prefills, get_decode_token_per_req)


 if TYPE_CHECKING:
@@ -186,6 +186,7 @@ def __init__(self,
         scheduler_config = vllm_config.scheduler_config
         self.block_size = vllm_config.cache_config.block_size
         self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size
+        self.decode_token_per_req = get_decode_token_per_req(vllm_config.speculative_config)
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
         if self.chunked_prefill_enabled:
             self.chunked_prefill_workspace_size = min(
@@ -293,7 +294,7 @@ def build_torchair_graph_dummy(
             device=device)
         block_table = self._get_graph_runner_block_tables(
             num_reqs, block_table)
-        num_tokens = num_reqs * common_attn_metadata.decode_token_per_req
+        num_tokens = num_reqs * self.decode_token_per_req
         seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
         seq_lens_list = [0] * num_reqs
         input_positions = torch.zeros(num_tokens,

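mla_v1.py keeps its manual spelling of ceiling division for max_blocks while attention_v1.py uses cdiv; the two forms agree, as this small sketch with hypothetical values shows:

def cdiv(a: int, b: int) -> int:
    return -(a // -b)

max_model_len = 32768   # hypothetical value
block_size = 128        # hypothetical value

max_blocks = (max_model_len + block_size - 1) // block_size
assert max_blocks == cdiv(max_model_len, block_size) == 256

The dummy torchair graph build then sizes its tensors as num_tokens = num_reqs * decode_token_per_req, now read from the builder rather than from the metadata object.
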
vllm_ascend/attention/utils.py

Lines changed: 16 additions & 11 deletions
@@ -1,5 +1,6 @@
 from dataclasses import dataclass

+from vllm.config import SpeculativeConfig
 from vllm_ascend.attention.attention_v1 import AscendAttentionState

 import torch
@@ -14,26 +15,23 @@ class AscendCommonAttentionMetadata:
     For many of the tensors we keep both GPU and CPU versions.
     """

-    query_start_loc: torch.Tensor = None
-    query_start_loc_cpu: torch.Tensor = None
+    query_start_loc: torch.Tensor
+    query_start_loc_cpu: torch.Tensor
     """(batch_size + 1,), the start location of each request in query Tensor"""

-    seq_lens: torch.Tensor = None
-    seq_lens_cpu: torch.Tensor = None
+    seq_lens_cpu: torch.Tensor
     """(batch_size,), the length of each request including both computed tokens
     and newly scheduled tokens"""

     num_reqs: int
     """Number of requests"""
     num_actual_tokens: int
     """Total number of tokens in batch"""
-    max_query_len: int
-    """Longest query in batch"""

     actual_seq_lengths_q: list[int] = None

-    block_table_tensor: torch.Tensor = None
-    slot_mapping_cpu: torch.Tensor = None
+    block_table_tensor: torch.Tensor
+    slot_mapping_cpu: torch.Tensor

     positions: torch.Tensor = None

@@ -47,7 +45,7 @@ class AscendCommonAttentionMetadata:

     enable_dbo_across_dp: bool = False

-    is_only_prefill: bool
+    is_only_prefill: bool = False

     graph_pad_size: int = -1

@@ -70,8 +68,6 @@ class TorchairCommonAttentionMetadata:
     attn_mask: torch.Tensor = None
     spec_attn_mask: torch.Tensor = None

-    decode_token_per_req: int
-
     graph_pad_size: int = -1


@@ -115,3 +111,12 @@ def split_decodes_and_prefills(
     num_decode_tokens = query_start_loc[first_prefill].item()
     num_prefill_tokens = num_tokens - num_decode_tokens
     return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
+
+
+def get_decode_token_per_req(speculative_config: SpeculativeConfig):
+    decode_token_per_req = 1
+    if not speculative_config:
+        return decode_token_per_req
+    spec_token_num = speculative_config.num_speculative_tokens
+    assert spec_token_num > 0
+    return decode_token_per_req + spec_token_num

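To illustrate the behaviour of the new helper in isolation, here is a runnable re-implementation against a stand-in config object (the _StubSpeculativeConfig class is hypothetical; the real argument is vllm's SpeculativeConfig):

from dataclasses import dataclass
from typing import Optional


@dataclass
class _StubSpeculativeConfig:
    num_speculative_tokens: int


def get_decode_token_per_req(speculative_config: Optional[_StubSpeculativeConfig]) -> int:
    # Without speculative decoding, each request decodes one token per step.
    if not speculative_config:
        return 1
    spec_token_num = speculative_config.num_speculative_tokens
    assert spec_token_num > 0
    # One target token plus the speculative draft tokens.
    return 1 + spec_token_num


print(get_decode_token_per_req(None))                        # 1
print(get_decode_token_per_req(_StubSpeculativeConfig(3)))   # 4
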
vllm_ascend/worker/model_runner_v1.py

Lines changed: 3 additions & 7 deletions
@@ -92,7 +92,7 @@
 from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
-from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, get_decode_token_per_req

 if not vllm_version_is("0.10.0"):
     from vllm.tasks import GenerationTask, SupportedTask
@@ -221,7 +221,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             use_mla=self.model_config.use_mla,
         )
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
-            weakref.proxy(self))
+            vllm_config, device)
         self.attn_mask_builder = AttentionMaskBuilder(
             min(self.model_config.max_model_len,
                 int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype)
@@ -234,13 +234,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.drafter: Optional[Union[NgramProposer, EagleProposer,
                                      MtpProposer]] = None
         self.actual_seq_lengths_q = []
-        self.spec_token_num = 0
-        self.decode_token_per_req = 1
+        self.decode_token_per_req = get_decode_token_per_req(self.speculative_config)
         if self.speculative_config:
             self.use_spec_decode = True
-            self.spec_token_num = self.speculative_config.num_speculative_tokens
-            assert self.spec_token_num > 0
-            self.decode_token_per_req = 1 + self.spec_token_num
             self.actual_seq_lengths_q = [
                 len for len in
                 range(self.decode_token_per_req, self.max_num_tokens +

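The runner-side change boils down to a new constructor contract: metadata builders are instantiated from (vllm_config, device) instead of a weakref.proxy of the whole runner. A toy sketch of that contract (_Toy* names are hypothetical stand-ins, not vllm classes):

class _ToyAttentionBackend:

    class _Builder:
        def __init__(self, vllm_config, device):
            # The builder captures what it needs from the config up front,
            # so it no longer reaches back into the model runner at build time.
            self.vllm_config = vllm_config
            self.device = device

    @classmethod
    def get_builder_cls(cls):
        return cls._Builder


vllm_config = object()   # stand-in for the real VllmConfig
builder = _ToyAttentionBackend.get_builder_cls()(vllm_config, "npu:0")
print(type(builder).__name__)   # _Builder
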