
Commit f71f1af

MTP supports V1 scheduler
Signed-off-by: xuyexiong <[email protected]>
1 parent 17c2884 commit f71f1af

File tree

5 files changed: +316 -208 lines changed


vllm_ascend/attention/attention_v1.py

Lines changed: 62 additions & 33 deletions
@@ -21,6 +21,8 @@
 
 import torch
 import torch_npu
+import torch.nn as nn
+from vllm.config import VllmConfig
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
@@ -30,11 +32,10 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
-from vllm_ascend.attention.utils import \
-    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import get_graph_params
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -156,39 +157,48 @@ def split_metadata_for_multistream(
 
 class AscendAttentionMetadataBuilder:
 
-    def __init__(self, runner):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        device: torch.device,
+        runner
+    ):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.device = device
         self.runner = runner
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
-    def build(self,
-              num_reqs,
-              num_actual_tokens,
-              max_query_len,
-              common_attn_metadata: CommonAttentionMetadata,
-              enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False,
-              *args,
-              **kwargs):
-
-        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-        )
-        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-            block_table[:num_reqs])
-
-        query_start_loc = common_attn_metadata.query_start_loc
-        seq_lens = common_attn_metadata.seq_lens
+    def build(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+    ):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
+                                                                       num_reqs
+                                                                       + 1]
+
+        block_table = common_attn_metadata.block_table_tensor
+        block_table[:num_reqs, :common_attn_metadata.
+                    max_num_blocks_per_req] = (block_table[:num_reqs])
+
+        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         # TODO: Refactor these two param to common metadata in runners,
         # preparing for the hybrid KV groups feature
-        query_lens = common_attn_metadata.query_lens or self.runner.query_lens
+        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         # Since FIA for GQA is not active now, we temporarily silence it
         seq_lens_list = common_attn_metadata.seq_lens_list
 
-        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
-        attn_mask = self.runner.attn_mask
-        attn_state = self.runner.attn_state
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to(
+            self.device, non_blocking=True)
+        attn_mask = common_attn_metadata.attn_mask
+        attn_state = common_attn_metadata.attn_state
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.device, non_blocking=True)
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -197,34 +207,53 @@ def build(self,
             query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens_list,
-            max_query_len=max_query_len,
+            max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp,
-            is_only_prefill=is_only_prefill)
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
+            is_only_prefill=common_attn_metadata.is_only_prefill)
         return attn_metadata
 
     def build_dummy_metadata(self, num_actual_tokens, num_reqs,
                              num_scheduled_tokens, attn_state):
         if attn_state == AscendAttentionState.DecodeOnly:
             # NOTE: We only need to pay attention to seq_lens_list and block_table here
-            common_attn_metadata = CommonAttentionMetadata(
+            common_attn_metadata = AscendCommonAttentionMetadata(
                 seq_lens=torch.empty_like(self.runner.seq_lens_cpu).fill_(2))
 
             block_table = self.runner.input_batch.block_table[0].block_table
             block_table[:num_reqs, 0] = torch.arange(1,
                                                      num_reqs + 1,
                                                      device=block_table.device,
                                                      dtype=block_table.dtype)
+            block_table = self.runner.input_batch.block_table[0].get_device_tensor(
+            )
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])
+
+            query_start_loc = common_attn_metadata.query_start_loc
+            seq_lens = common_attn_metadata.seq_lens
+            query_lens = self.runner.query_lens
+            seq_lens_list = None
 
-            attn_metadata = self.build(
-                num_reqs=num_reqs,
+            slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
+            attn_mask = self.runner.attn_mask
+            attn_state = self.runner.attn_state
+
+            attn_metadata = AscendMetadata(
                 num_actual_tokens=num_actual_tokens,
+                block_tables=block_table,
+                query_start_loc=query_start_loc,
+                query_lens=query_lens,
+                seq_lens=seq_lens,
+                seq_lens_list=seq_lens_list,
                 max_query_len=num_scheduled_tokens.max(),
-                common_prefix_len=0,
-                common_attn_metadata=common_attn_metadata,
-            )
+                slot_mapping=slot_mapping,
+                attn_mask=attn_mask,
+                attn_state=attn_state,
+                enable_dbo_across_dp=False,
+                is_only_prefill=False)
         else:
             raise NotImplementedError(
                 "Currently we only support building dummy metadata for DecodeOnly state"

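As a reading aid for the hunks above: after this change, AscendAttentionMetadataBuilder.build() takes a single AscendCommonAttentionMetadata object and derives everything else from it, instead of pulling per-batch state off the runner. The snippet below is a minimal, hedged sketch of that data flow using plain torch tensors on CPU. The ToyCommonAttentionMetadata stand-in, the toy batch sizes, and the CPU device are illustrative assumptions (the real dataclass lives in vllm_ascend/attention/utils.py and the real code targets an NPU device); only the field names and the derivations mirror what the diff actually reads.

# Hedged sketch, not part of the commit: reproduces the derivations the new build()
# performs, with a stand-in metadata object carrying only the fields the diff reads.
from dataclasses import dataclass

import torch


@dataclass
class ToyCommonAttentionMetadata:
    """Stand-in with only the fields the refactored build() accesses."""
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int
    query_start_loc_cpu: torch.Tensor
    seq_lens_cpu: torch.Tensor
    slot_mapping_cpu: torch.Tensor
    block_table_tensor: torch.Tensor
    max_num_blocks_per_req: int


# Toy batch: 2 requests with 3 and 5 scheduled tokens.
meta = ToyCommonAttentionMetadata(
    num_reqs=2,
    num_actual_tokens=8,
    max_query_len=5,
    query_start_loc_cpu=torch.tensor([0, 3, 8], dtype=torch.int32),
    seq_lens_cpu=torch.tensor([3, 5], dtype=torch.int32),
    slot_mapping_cpu=torch.arange(8, dtype=torch.int64),
    block_table_tensor=torch.zeros(2, 4, dtype=torch.int32),
    max_num_blocks_per_req=4,
)

device = torch.device("cpu")  # assumption for illustration; the real code uses an NPU device

# Same derivations the new build() performs:
query_start_loc_cpu = meta.query_start_loc_cpu[:meta.num_reqs + 1]
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]  # per-request token counts
seq_lens = meta.seq_lens_cpu[:meta.num_reqs]
slot_mapping = meta.slot_mapping_cpu[:meta.num_actual_tokens].to(device, non_blocking=True)
query_start_loc = query_start_loc_cpu.to(device, non_blocking=True)

print(query_lens.tolist(), seq_lens.tolist(), slot_mapping.shape[0])  # [3, 5] [3, 5] 8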