
Commit 619fff4

MTP supports V1 scheduler

Signed-off-by: xuyexiong <[email protected]>
1 parent 17c2884 commit 619fff4

File tree

5 files changed: +322 additions, -208 deletions


vllm_ascend/attention/attention_v1.py

Lines changed: 63 additions & 33 deletions
@@ -21,6 +21,8 @@
 
 import torch
 import torch_npu
+import torch.nn as nn
+from vllm.config import VllmConfig
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
@@ -30,11 +32,10 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
-from vllm_ascend.attention.utils import \
-    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import get_graph_params
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 
 
 class AscendAttentionBackend(AttentionBackend):
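
After this change, the metadata builder reads everything it needs from a single AscendCommonAttentionMetadata object instead of reaching into the runner. For orientation only, a container with the fields that build() touches in this diff might be sketched as below; the real definition lives in vllm_ascend/attention/utils.py, and its exact fields, types, and defaults may differ.

```python
# Hypothetical sketch: field names are inferred solely from how build() reads
# them in this commit; the authoritative class is AscendCommonAttentionMetadata
# in vllm_ascend/attention/utils.py.
from dataclasses import dataclass
from typing import Any, Optional

import torch


@dataclass
class CommonAttentionMetadataSketch:
    num_reqs: int                      # number of scheduled requests
    num_actual_tokens: int             # total tokens scheduled this step
    max_query_len: int
    max_num_blocks_per_req: int
    query_start_loc_cpu: torch.Tensor  # [num_reqs + 1] prefix sums, on CPU
    seq_lens_cpu: torch.Tensor         # [num_reqs] sequence lengths, on CPU
    block_table_tensor: torch.Tensor   # [num_reqs, max_num_blocks_per_req]
    slot_mapping_cpu: torch.Tensor     # [num_actual_tokens] KV-cache slot indices
    attn_mask: Optional[torch.Tensor] = None
    attn_state: Any = None             # e.g. AscendAttentionState.DecodeOnly
    seq_lens_list: Optional[list] = None
    enable_dbo_across_dp: bool = False
    is_only_prefill: bool = False
```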
@@ -156,39 +157,49 @@ def split_metadata_for_multistream(
 
 class AscendAttentionMetadataBuilder:
 
-    def __init__(self, runner):
+    def __init__(self, vllm_config: VllmConfig, device: torch.device, runner):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.device = device
         self.runner = runner
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
-    def build(self,
-              num_reqs,
-              num_actual_tokens,
-              max_query_len,
-              common_attn_metadata: CommonAttentionMetadata,
-              enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False,
-              *args,
-              **kwargs):
-
-        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-        )
-        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-            block_table[:num_reqs])
-
-        query_start_loc = common_attn_metadata.query_start_loc
-        seq_lens = common_attn_metadata.seq_lens
+    def build(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+    ):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
+
+        block_table = common_attn_metadata.block_table_tensor
+        block_table[:num_reqs, :common_attn_metadata.max_num_blocks_per_req] = (
+            block_table[:num_reqs])
+
+        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         # TODO: Refactor these two param to common metadata in runners,
         # preparing for the hybrid KV groups feature
-        query_lens = common_attn_metadata.query_lens or self.runner.query_lens
+        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         # Since FIA for GQA is not active now, we temporarily silence it
         seq_lens_list = common_attn_metadata.seq_lens_list
 
-        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
-        attn_mask = self.runner.attn_mask
-        attn_state = self.runner.attn_state
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to(
+            self.device, non_blocking=True)
+        attn_mask = common_attn_metadata.attn_mask
+        attn_state = common_attn_metadata.attn_state
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.device, non_blocking=True)
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
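
Two idioms in the new build() are worth calling out: per-request query lengths are now derived from the CPU-side prefix-sum tensor rather than taken from the runner, and CPU tensors are copied to the device with non_blocking=True. A minimal, self-contained illustration with made-up values:

```python
import torch

# query_start_loc_cpu holds prefix sums of per-request query lengths:
# request i covers tokens [query_start_loc_cpu[i], query_start_loc_cpu[i + 1]).
query_start_loc_cpu = torch.tensor([0, 3, 5, 9])  # 3 requests, made-up values

# First differences of the prefix sums give the per-request query lengths.
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
print(query_lens.tolist())  # [3, 2, 4]

# build() then moves CPU-side tensors to the NPU asynchronously; "cpu" is used
# here only so the snippet runs anywhere without torch_npu.
device = torch.device("cpu")
query_start_loc = query_start_loc_cpu.to(device, non_blocking=True)
```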
@@ -197,34 +208,53 @@ def build(self,
             query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens_list,
-            max_query_len=max_query_len,
+            max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp,
-            is_only_prefill=is_only_prefill)
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
+            is_only_prefill=common_attn_metadata.is_only_prefill)
         return attn_metadata
 
     def build_dummy_metadata(self, num_actual_tokens, num_reqs,
                              num_scheduled_tokens, attn_state):
         if attn_state == AscendAttentionState.DecodeOnly:
             # NOTE: We only need to pay attention to seq_lens_list and block_table here
-            common_attn_metadata = CommonAttentionMetadata(
+            common_attn_metadata = AscendCommonAttentionMetadata(
                 seq_lens=torch.empty_like(self.runner.seq_lens_cpu).fill_(2))
 
             block_table = self.runner.input_batch.block_table[0].block_table
             block_table[:num_reqs, 0] = torch.arange(1,
                                                      num_reqs + 1,
                                                      device=block_table.device,
                                                      dtype=block_table.dtype)
+            block_table = self.runner.input_batch.block_table[0].get_device_tensor()
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])
 
-            attn_metadata = self.build(
-                num_reqs=num_reqs,
+            query_start_loc = common_attn_metadata.query_start_loc
+            seq_lens = common_attn_metadata.seq_lens
+            query_lens = self.runner.query_lens
+            seq_lens_list = None
+
+            slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
+            attn_mask = self.runner.attn_mask
+            attn_state = self.runner.attn_state
+
+            attn_metadata = AscendMetadata(
                 num_actual_tokens=num_actual_tokens,
+                block_tables=block_table,
+                query_start_loc=query_start_loc,
+                query_lens=query_lens,
+                seq_lens=seq_lens,
+                seq_lens_list=seq_lens_list,
                 max_query_len=num_scheduled_tokens.max(),
-                common_prefix_len=0,
-                common_attn_metadata=common_attn_metadata,
-            )
+                slot_mapping=slot_mapping,
+                attn_mask=attn_mask,
+                attn_state=attn_state,
+                enable_dbo_across_dp=False,
+                is_only_prefill=False)
         else:
             raise NotImplementedError(
                 "Currently we only support building dummy metadata for DecodeOnly state"
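Taken together, the new constructor and build() signature change how the model runner drives the builder. Below is a hedged sketch of the calling convention implied by this diff; the helper name build_attn_metadata and the way the runner obtains vllm_config and common_attn_metadata are assumptions, not code from this commit.

```python
import torch
from vllm.config import VllmConfig

from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata


def build_attn_metadata(vllm_config: VllmConfig, runner,
                        common_attn_metadata: AscendCommonAttentionMetadata):
    """Hypothetical helper showing the post-commit calling convention."""
    builder = AscendAttentionMetadataBuilder(
        vllm_config=vllm_config,     # engine-wide config, new in this commit
        device=torch.device("npu"),  # target NPU device, new in this commit
        runner=runner,               # still needed by build_dummy_metadata()
    )
    # build() now takes the whole metadata container instead of separate
    # num_reqs / num_actual_tokens / max_query_len arguments.
    return builder.build(common_attn_metadata)
```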