Commit 8c74a86

refactor attn metadata build
Signed-off-by: weiguihua2 <[email protected]>
1 parent: be66ea7

6 files changed (+45 −42 lines)

vllm_ascend/attention/attention_v1_torchair.py

Lines changed: 4 additions & 4 deletions
@@ -33,7 +33,7 @@
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d)
 from vllm_ascend.worker.npu_input_batch import InputBatch
-from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, get_decode_token_per_req
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata


 class AscendAttentionTorchairBackend(AttentionBackend):
@@ -157,7 +157,7 @@ def __init__(self,
         self.device = device
         self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
                                            vllm_config.cache_config.block_size)
-        self.decode_token_per_req = get_decode_token_per_req(vllm_config.speculative_config)
+        self.max_blocks = (self.model_config.max_model_len + vllm_config.cache_config.block_size - 1) // vllm_config.cache_config.block_size

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -175,7 +175,7 @@ def build_torchair_graph_dummy(
             self, common_attn_metadata: AscendCommonAttentionMetadata) -> AscendTorchairMetadata:
         device = self.device
         num_reqs = common_attn_metadata.num_reqs
-        _, max_blocks = self.runner.graph_block_tables.shape
+        max_blocks = self.max_blocks
         block_table = torch.zeros((num_reqs, max_blocks),
                                   dtype=torch.int32,
                                   device=device)
@@ -257,7 +257,7 @@ def build(self,
             pad_value = 0
             num_token_pad_size = graph_pad_size - num_actual_tokens
             num_reqs_pad_size = (
-                graph_pad_size // self.decode_token_per_req -
+                graph_pad_size // common_attn_metadata.decode_token_per_req -
                 num_reqs)
             pad_value = 1
             padded_seq_lens = seq_lens.tolist() + [pad_value
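
The builder no longer reaches into the runner for graph_block_tables: the number of KV-cache blocks per request is derived directly from the model and cache config by ceiling division, and the per-request decode-token count now travels inside the metadata. A minimal standalone sketch of that block computation (compute_max_blocks is a hypothetical helper name; the diff inlines the expression in __init__):

# Sketch of the new max_blocks derivation, assuming only max_model_len and block_size.
def compute_max_blocks(max_model_len: int, block_size: int) -> int:
    # Ceiling division: how many KV-cache blocks a request of maximal length needs.
    return (max_model_len + block_size - 1) // block_size

# Example: a 32768-token context with 128-token blocks needs 256 blocks.
assert compute_max_blocks(32768, 128) == 256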

vllm_ascend/attention/mla_v1.py

Lines changed: 9 additions & 11 deletions
@@ -25,7 +25,7 @@
 from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
-from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,split_decodes_and_prefills, get_decode_token_per_req)
+from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,split_decodes_and_prefills)


 if TYPE_CHECKING:
@@ -186,7 +186,6 @@ def __init__(self,
         scheduler_config = vllm_config.scheduler_config
         self.block_size = vllm_config.cache_config.block_size
         self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size
-        self.decode_token_per_req = get_decode_token_per_req(vllm_config.speculative_config)
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
         if self.chunked_prefill_enabled:
             self.chunked_prefill_workspace_size = min(
@@ -288,13 +287,13 @@ def build_torchair_graph_dummy(
             self, common_attn_metadata: AscendCommonAttentionMetadata,) -> AscendMLAMetadata:
         device = self.device
         num_reqs = common_attn_metadata.num_reqs
-        _, max_blocks = self.runner.graph_block_tables.shape
+        max_blocks = self.max_blocks
         block_table = torch.zeros((num_reqs, max_blocks),
                                   dtype=torch.int32,
                                   device=device)
         block_table = self._get_graph_runner_block_tables(
             num_reqs, block_table)
-        num_tokens = num_reqs * self.decode_token_per_req
+        num_tokens = num_reqs * common_attn_metadata.decode_token_per_req
         seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
         seq_lens_list = [0] * num_reqs
         input_positions = torch.zeros(num_tokens,
@@ -382,8 +381,8 @@ def build(
         input_positions = common_attn_metadata.positions[:num_actual_tokens].long()

         if self.cos_cache is None:
-            self.cos_cache = model.layers[0].self_attn.rotary_emb.cos_cached
-            self.sin_cache = model.layers[0].self_attn.rotary_emb.sin_cached
+            self.cos_cache = model.model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = model.model.layers[0].self_attn.rotary_emb.sin_cached
         if self.cos_cache.dtype != self.model_config.dtype:  # type: ignore
             self.cos_cache = self.cos_cache.to(  # type: ignore
                 self.model_config.dtype)  # type: ignore
@@ -392,10 +391,9 @@ def build(

         query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         query_lens = query_seq_lens_cpu[:num_reqs]
-        num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu -
-                                   query_seq_lens_cpu)
-
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
+        num_computed_tokens_cpu = (seq_lens - query_lens)
+
         prefill_metadata = None
         chunked_context_metadata = None
         if num_prefills > 0:
@@ -418,12 +416,12 @@ def build(
                 assert max_context_chunk > 0
                 num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
                 chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
-                    .unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk
+                    .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
                 chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                        chunk_starts + max_context_chunk)
                 chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
                 cu_seq_lens_cpu = torch.zeros(num_chunks,
-                                              self._num_prefills + 1,
+                                              num_prefills + 1,
                                               dtype=torch.int32,
                                               pin_memory=True)
                 torch.cumsum(chunk_seq_lens,
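
In build(), num_computed_tokens_cpu is now derived from the per-request slices (seq_lens - query_lens) rather than from the full-length tensors. A small self-contained sketch of that arithmetic, using made-up input values in place of the builder's real metadata:

import torch

# query_start_loc_cpu, seq_lens_cpu and num_reqs stand in for the builder's inputs.
query_start_loc_cpu = torch.tensor([0, 3, 7, 12], dtype=torch.int32)
seq_lens_cpu = torch.tensor([10, 8, 15], dtype=torch.int32)
num_reqs = 3

query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
query_lens = query_seq_lens_cpu[:num_reqs]       # tokens scheduled this step
seq_lens = seq_lens_cpu[:num_reqs]               # total tokens per request
num_computed_tokens_cpu = seq_lens - query_lens  # tokens already in the KV cache
print(num_computed_tokens_cpu)                   # tensor([7, 4, 10], dtype=torch.int32)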

vllm_ascend/attention/utils.py

Lines changed: 9 additions & 13 deletions
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
+from enum import Enum

 from vllm.config import SpeculativeConfig
-from vllm_ascend.attention.attention_v1 import AscendAttentionState

 import torch

@@ -28,18 +28,20 @@ class AscendCommonAttentionMetadata:
     num_actual_tokens: int
     """Total number of tokens in batch"""

-    actual_seq_lengths_q: list[int] = None
+    max_query_len: int
+
+    decode_token_per_req: int

     block_table_tensor: torch.Tensor
     slot_mapping_cpu: torch.Tensor

+    actual_seq_lengths_q: list[int] = None
+
     positions: torch.Tensor = None

     attn_mask: torch.Tensor = None
     spec_attn_mask: torch.Tensor = None
-    attn_state: AscendAttentionState = None
-
-    max_query_len: int
+    attn_state: Enum = None

     enable_dbo_across_dp: bool = False

@@ -61,6 +63,8 @@ class TorchairCommonAttentionMetadata:
     num_actual_tokens: int
     """Total number of tokens in batch"""

+    decode_token_per_req: int
+
     actual_seq_lengths_q: list[int] = None

     attn_mask: torch.Tensor = None
@@ -110,11 +114,3 @@ def split_decodes_and_prefills(
     num_prefill_tokens = num_tokens - num_decode_tokens
     return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)

-
-def get_decode_token_per_req(speculative_config: SpeculativeConfig):
-    decode_token_per_req = 1
-    if not speculative_config:
-        return decode_token_per_req
-    spec_token_num = speculative_config.num_speculative_tokens
-    assert spec_token_num > 0
-    return decode_token_per_req + spec_token_num
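
The field reordering in AscendCommonAttentionMetadata is not cosmetic: in a Python dataclass every field without a default must precede the fields that have one, so the required max_query_len and decode_token_per_req now sit ahead of the defaulted entries. Typing attn_state as a plain Enum also lets utils.py drop its import from attention_v1, presumably to avoid a circular import. A reduced sketch of the resulting layout, with most fields omitted:

from dataclasses import dataclass
from enum import Enum
import torch

@dataclass
class CommonAttentionMetadataSketch:
    # Required fields (no defaults) must come first, otherwise Python raises
    # "TypeError: non-default argument follows default argument".
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int
    decode_token_per_req: int
    block_table_tensor: torch.Tensor
    slot_mapping_cpu: torch.Tensor
    # Defaulted fields follow.
    actual_seq_lengths_q: list[int] = None
    attn_state: Enum = None  # plain Enum keeps utils.py free of the attention_v1 import
    enable_dbo_across_dp: bool = False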

vllm_ascend/worker/eagle_proposer_v1.py

Lines changed: 2 additions & 1 deletion
@@ -140,9 +140,10 @@ def propose(
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
             attn_state=self.runner.attn_state,
+            decode_token_per_req=self.runner.decode_token_per_req,
         )
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-        attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata)
+        attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata, self.runner.model)
         if self.use_cuda_graph and \
                 num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)

vllm_ascend/worker/model_runner_v1.py

Lines changed: 11 additions & 5 deletions
@@ -92,7 +92,7 @@
 from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
-from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, get_decode_token_per_req
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata

 if not vllm_version_is("0.10.0"):
     from vllm.tasks import GenerationTask, SupportedTask
@@ -234,9 +234,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.drafter: Optional[Union[NgramProposer, EagleProposer,
                                      MtpProposer]] = None
         self.actual_seq_lengths_q = []
-        self.decode_token_per_req = get_decode_token_per_req(self.speculative_config)
+        self.decode_token_per_req = 1
         if self.speculative_config:
             self.use_spec_decode = True
+            spec_token_num = self.speculative_config.num_speculative_tokens
+            assert spec_token_num > 0
+            self.decode_token_per_req = 1 + spec_token_num
             self.actual_seq_lengths_q = [
                 len for len in
                 range(self.decode_token_per_req, self.max_num_tokens +
@@ -813,8 +816,9 @@ def get_eagle_atten_dict(
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
             max_num_blocks_per_req=self.max_num_blocks_per_req,
+            decode_token_per_req=self.decode_token_per_req,
         )
-        attn_metadata_i = self.attn_metadata_builder.build(common_attn_metadata)
+        attn_metadata_i = self.attn_metadata_builder.build(common_attn_metadata, self.model)
         for layer_name in kv_cache_group_spec.layer_names:
             attn_metadata[layer_name] = attn_metadata_i

@@ -1233,9 +1237,11 @@ def _process_reqs(
             attn_state=self.attn_state,
             enable_dbo_across_dp=enable_dbo,
             is_only_prefill=is_only_prefill,
-            graph_pad_size=self.graph_pad_size
+            max_query_len=max_num_scheduled_tokens,
+            graph_pad_size=self.graph_pad_size,
+            decode_token_per_req=self.decode_token_per_req,
         )
-        attn_metadata = self.attn_metadata_builder.build(common_attn_metadata)
+        attn_metadata = self.attn_metadata_builder.build(common_attn_metadata, self.model)
         if self.vllm_config.model_config.use_mla:
             attn_metadata.num_input_tokens = num_input_tokens

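
With the get_decode_token_per_req helper removed from utils.py, the runner now computes the per-request decode width inline: one verified token per decode step, plus the configured draft tokens when speculative decoding is enabled. A self-contained sketch of the same logic, using a SimpleNamespace stand-in for self.speculative_config:

from types import SimpleNamespace

# Stand-in for self.speculative_config; None would mean speculative decoding is off.
speculative_config = SimpleNamespace(num_speculative_tokens=3)

decode_token_per_req = 1
if speculative_config:
    spec_token_num = speculative_config.num_speculative_tokens
    assert spec_token_num > 0
    decode_token_per_req = 1 + spec_token_num

print(decode_token_per_req)  # 4: one target token plus three draft tokens per request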

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 10 additions & 8 deletions
@@ -181,9 +181,10 @@ def propose(
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
             attn_state=self.runner.attn_state,
-            graph_pad_size=extra_builder_kwargs['graph_pad_size']
+            graph_pad_size=extra_builder_kwargs['graph_pad_size'],
+            decode_token_per_req=self.runner.decode_token_per_req,
         )
-        attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata)
+        attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata, self.runner.model)

         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
@@ -294,12 +295,13 @@ def dummy_run(self,
             attn_metadata = None
         else:
             common_attn_metadata = TorchairCommonAttentionMetadata(
-                num_reqs=num_reqs,
-                num_actual_tokens=1,
-                actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
-                attn_mask=self.runner.attn_mask,
-                spec_attn_mask=self.runner.spec_attn_mask,
-            )
+                num_reqs=num_reqs,
+                num_actual_tokens=1,
+                actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
+                attn_mask=self.runner.attn_mask,
+                spec_attn_mask=self.runner.spec_attn_mask,
+                decode_token_per_req=self.runner.decode_token_per_req,
+            )
             attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(common_attn_metadata)

             input_ids = self.input_ids[:num_tokens]
