Skip to content

Commit ef3b644

Browse files
committed
refactor attn metadata build
Signed-off-by: weiguihua2 <[email protected]>
1 parent 52f99e6 commit ef3b644

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

vllm_ascend/attention/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -106,7 +106,7 @@ def split_decodes_and_prefills(
106106
return num_reqs, 0, num_tokens, 0
107107

108108
first_prefill = is_prefill.int().argmax(dim=-1).item()
109-
assert torch.all(query_lens[first_prefill:] > decode_threshold)
109+
assert torch.all(query_lens[first_prefill:] >= decode_threshold)
110110
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
111111
num_decodes = first_prefill
112112
num_prefills = num_reqs - num_decodes

vllm_ascend/worker/eagle_proposer_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def propose(
128128

129129
common_attn_metadata = AscendCommonAttentionMetadata(
130130
query_start_loc=self.runner.query_start_loc[:batch_size + 1],
131-
query_start_loc_cpu=self.query_start_loc_cpu[:batch_size + 1],
131+
query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size + 1],
132132
seq_lens_cpu=self.runner.seq_lens_cpu,
133133
max_query_len=max_query_len,
134134
num_reqs=batch_size,

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,7 @@ def get_eagle_atten_dict(
815815
max_num_blocks_per_req=self.max_num_blocks_per_req,
816816
decode_token_per_req=self.decode_token_per_req,
817817
)
818-
attn_metadata_i = self.attn_metadata_builder.build(common_attn_metadata, self.model)
818+
attn_metadata_i = self.attn_metadata_builder.build(common_attn_metadata, self.get_model())
819819
for layer_name in kv_cache_group_spec.layer_names:
820820
attn_metadata[layer_name] = attn_metadata_i
821821

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,9 @@ def propose(
168168
num_input_tokens = num_tokens
169169

170170
common_attn_metadata = AscendCommonAttentionMetadata(
171-
query_start_loc=self.runner.query_start_loc[:batch_size + 1],
172-
query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size + 1],
173-
seq_lens_cpu=target_positions.cpu()[last_token_indices] + 1,
171+
query_start_loc=cu_num_tokens[:batch_size + 1],
172+
query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(),
173+
seq_lens_cpu=seq_lens.cpu(),
174174
num_reqs=batch_size,
175175
num_actual_tokens=num_tokens,
176176
max_query_len=max_query_len,
@@ -184,7 +184,7 @@ def propose(
184184
graph_pad_size=extra_builder_kwargs['graph_pad_size'],
185185
decode_token_per_req=self.runner.decode_token_per_req,
186186
)
187-
attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata, self.runner.model)
187+
attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata, self.runner.get_model())
188188

189189
self.positions[:num_tokens] = target_positions
190190
self.hidden_states[:num_tokens] = target_hidden_states

0 commit comments

Comments
 (0)