Skip to content

Commit 8bad42e

Browse files
committed
MTP supports V1 scheduler
Signed-off-by: xuyexiong <[email protected]>
1 parent 9537306 commit 8bad42e

File tree

5 files changed

+321
-212
lines changed

5 files changed

+321
-212
lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 61 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -24,14 +24,13 @@
2424
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
2525
AttentionLayer, AttentionType)
2626
from vllm.attention.backends.utils import CommonAttentionState
27-
from vllm.config import get_current_vllm_config
27+
from vllm.config import VllmConfig, get_current_vllm_config
2828
from vllm.forward_context import ForwardContext, get_forward_context
2929
from vllm.utils import direct_register_custom_op
3030
from vllm.v1.core.sched.output import SchedulerOutput
3131
from vllm.v1.worker.gpu_input_batch import InputBatch
3232

33-
from vllm_ascend.attention.utils import \
34-
AscendCommonAttentionMetadata as CommonAttentionMetadata
33+
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
3534
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
3635
from vllm_ascend.ops.attention import vanilla_chunked_prefill
3736
from vllm_ascend.utils import get_graph_params
@@ -156,39 +155,49 @@ def split_metadata_for_multistream(
156155

157156
class AscendAttentionMetadataBuilder:
158157

159-
def __init__(self, runner):
158+
def __init__(self, vllm_config: VllmConfig, device: torch.device, runner):
159+
self.vllm_config = vllm_config
160+
self.model_config = vllm_config.model_config
161+
self.device = device
160162
self.runner = runner
161163

162164
def reorder_batch(self, input_batch: "InputBatch",
163165
scheduler_output: "SchedulerOutput") -> bool:
164166
return False
165167

166-
def build(self,
167-
num_reqs,
168-
num_actual_tokens,
169-
max_query_len,
170-
common_attn_metadata: CommonAttentionMetadata,
171-
enable_dbo_across_dp: bool = False,
172-
is_only_prefill: bool = False,
173-
*args,
174-
**kwargs):
175-
176-
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
177-
)
178-
block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
179-
block_table[:num_reqs])
180-
181-
query_start_loc = common_attn_metadata.query_start_loc
182-
seq_lens = common_attn_metadata.seq_lens
168+
def build(
169+
self,
170+
common_attn_metadata: AscendCommonAttentionMetadata,
171+
):
172+
num_reqs = common_attn_metadata.num_reqs
173+
num_actual_tokens = common_attn_metadata.num_actual_tokens
174+
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
175+
num_reqs
176+
+ 1]
177+
178+
block_table = common_attn_metadata.block_table_tensor
179+
block_table[:num_reqs, :common_attn_metadata.
180+
max_num_blocks_per_req] = (block_table[:num_reqs])
181+
182+
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
183183
# TODO: Refactor these two params into common metadata in runners,
184184
# preparing for the hybrid KV groups feature
185-
query_lens = common_attn_metadata.query_lens or self.runner.query_lens
185+
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
186186
# Since FIA for GQA is not active now, we temporarily silence it
187187
seq_lens_list = common_attn_metadata.seq_lens_list
188188

189-
slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
190-
attn_mask = self.runner.attn_mask
191-
attn_state = self.runner.attn_state
189+
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
190+
num_actual_tokens].to(
191+
self.device,
192+
non_blocking=
193+
True)
194+
attn_mask = common_attn_metadata.attn_mask
195+
attn_state = common_attn_metadata.attn_state
196+
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
197+
num_reqs
198+
+ 1]
199+
query_start_loc = query_start_loc_cpu.to(self.device,
200+
non_blocking=True)
192201

193202
attn_metadata = AscendMetadata(
194203
num_actual_tokens=num_actual_tokens,
@@ -197,34 +206,50 @@ def build(self,
197206
query_lens=query_lens,
198207
seq_lens=seq_lens,
199208
seq_lens_list=seq_lens_list,
200-
max_query_len=max_query_len,
209+
max_query_len=common_attn_metadata.max_query_len,
201210
slot_mapping=slot_mapping,
202211
attn_mask=attn_mask,
203212
attn_state=attn_state,
204-
enable_dbo_across_dp=enable_dbo_across_dp,
205-
is_only_prefill=is_only_prefill)
213+
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
214+
is_only_prefill=common_attn_metadata.is_only_prefill)
206215
return attn_metadata
207216

208217
def build_dummy_metadata(self, num_actual_tokens, num_reqs,
209218
num_scheduled_tokens, attn_state):
210219
if attn_state == AscendAttentionState.DecodeOnly:
211220
# NOTE: We only need to pay attention to seq_lens_list and block_table here
212-
common_attn_metadata = CommonAttentionMetadata(
213-
seq_lens=torch.empty_like(self.runner.seq_lens_cpu).fill_(2))
214-
215221
block_table = self.runner.input_batch.block_table[0].block_table
216222
block_table[:num_reqs, 0] = torch.arange(1,
217223
num_reqs + 1,
218224
device=block_table.device,
219225
dtype=block_table.dtype)
226+
block_table = self.runner.input_batch.block_table[
227+
0].get_device_tensor()
228+
block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
229+
block_table[:num_reqs])
220230

221-
attn_metadata = self.build(
222-
num_reqs=num_reqs,
231+
query_start_loc = None
232+
seq_lens = torch.empty_like(self.runner.seq_lens_cpu).fill_(2)
233+
query_lens = self.runner.query_lens
234+
seq_lens_list = None
235+
236+
slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
237+
attn_mask = self.runner.attn_mask
238+
attn_state = self.runner.attn_state
239+
240+
attn_metadata = AscendMetadata(
223241
num_actual_tokens=num_actual_tokens,
242+
block_tables=block_table,
243+
query_start_loc=query_start_loc,
244+
query_lens=query_lens,
245+
seq_lens=seq_lens,
246+
seq_lens_list=seq_lens_list,
224247
max_query_len=num_scheduled_tokens.max(),
225-
common_prefix_len=0,
226-
common_attn_metadata=common_attn_metadata,
227-
)
248+
slot_mapping=slot_mapping,
249+
attn_mask=attn_mask,
250+
attn_state=attn_state,
251+
enable_dbo_across_dp=False,
252+
is_only_prefill=False)
228253
else:
229254
raise NotImplementedError(
230255
"Currently we only support building dummy metadata for DecodeOnly state"

0 commit comments

Comments
 (0)