
Commit be66ea7

refact attn metadata build

Signed-off-by: weiguihua2 <[email protected]>

1 parent: c416bbf

File tree

5 files changed: +6 −22 lines


vllm_ascend/attention/utils.py

Lines changed: 2 additions & 4 deletions
@@ -39,10 +39,8 @@ class AscendCommonAttentionMetadata:
     spec_attn_mask: torch.Tensor = None
     attn_state: AscendAttentionState = None
 
-    decode_token_per_req: int
-
-    max_num_blocks_per_req: int
-
+    max_query_len: int
+
     enable_dbo_across_dp: bool = False
 
     is_only_prefill: bool = False
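
After this hunk, the shared metadata dataclass no longer carries decode_token_per_req or max_num_blocks_per_req and requires max_query_len instead. The following is a minimal sketch of that shape, not the real vllm_ascend definition; the field subset and ordering are assumptions drawn from the hunks in this commit.

# Hedged sketch only: field subset and ordering inferred from this commit's hunks,
# not copied from vllm_ascend/attention/utils.py.
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class CommonAttentionMetadataSketch:
    query_start_loc: torch.Tensor
    seq_lens_cpu: torch.Tensor
    max_query_len: int                 # newly required field in this commit
    num_reqs: int
    num_actual_tokens: int
    spec_attn_mask: Optional[torch.Tensor] = None
    enable_dbo_across_dp: bool = False
    is_only_prefill: bool = False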

vllm_ascend/torchair/torchair_model_runner.py

Lines changed: 0 additions & 1 deletion
@@ -77,7 +77,6 @@ def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
                 attn_mask=self.attn_mask,
                 spec_attn_mask=self.spec_attn_mask,
                 attn_state=self.attn_state,
-                decode_token_per_req=self.decode_token_per_req,
             )
             attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(common_attn_metadata)
         else:

vllm_ascend/worker/eagle_proposer_v1.py

Lines changed: 3 additions & 6 deletions
@@ -129,20 +129,17 @@ def propose(
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=self.runner.query_start_loc[:batch_size + 1],
             query_start_loc_cpu=self.query_start_loc_cpu[:batch_size + 1],
-            seq_lens=self.runner.seq_lens,
             seq_lens_cpu=self.runner.seq_lens_cpu,
+            max_query_len=max_query_len,
             num_reqs=batch_size,
             num_actual_tokens=num_tokens,
-            max_query_len=max_query_len,
             actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
             block_table_tensor=self.runner.input_batch.block_table[0].get_device_tensor(),
-            slot_mapping_cpu=self.runner.slot_mapping_cpu,
-            positions=self.positions,
+            slot_mapping_cpu=target_slot_mapping,
+            positions=target_positions,
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
             attn_state=self.runner.attn_state,
-            decode_token_per_req=self.runner.decode_token_per_req,
-            max_num_blocks_per_req=self.runner.max_num_blocks_per_req,
         )
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
         attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata)
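
In this call site the proposer now passes max_query_len directly. As a hedged illustration (an assumption about the usual vLLM semantics of query_start_loc, not code from this commit), query_start_loc is the cumulative sum of per-request query lengths, so a max_query_len value can be recovered from its successive differences:

import torch

# query_start_loc for 4 requests with query lengths 4, 0, 5, 2 (illustrative values)
query_start_loc_cpu = torch.tensor([0, 4, 4, 9, 11])
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = int(query_lens.max())
print(query_lens.tolist(), max_query_len)  # [4, 0, 5, 2] 5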

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 7 deletions
@@ -801,19 +801,17 @@ def get_eagle_atten_dict(
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=self.query_start_loc[:num_reqs + 1],
             query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-            seq_lens=self.seq_lens,
             seq_lens_cpu=self.seq_lens_cpu,
             num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
             max_query_len=max_num_scheduled_tokens,
+            num_actual_tokens=total_num_scheduled_tokens,
             actual_seq_lengths_q=self.actual_seq_lengths_q,
             block_table_tensor=self.input_batch.block_table[0].get_device_tensor(),
             slot_mapping_cpu=self.slot_mapping_cpu,
             positions=self.positions,
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
-            decode_token_per_req=self.decode_token_per_req,
             max_num_blocks_per_req=self.max_num_blocks_per_req,
         )
         attn_metadata_i = self.attn_metadata_builder.build(common_attn_metadata)

@@ -1223,20 +1221,16 @@ def _process_reqs(
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=self.query_start_loc[:num_reqs + 1],
             query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-            seq_lens=self.seq_lens,
             seq_lens_cpu=self.seq_lens_cpu,
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
             actual_seq_lengths_q=self.actual_seq_lengths_q,
             block_table_tensor=self.input_batch.block_table[0].get_device_tensor(),
             slot_mapping_cpu=self.slot_mapping_cpu,
             positions=self.positions,
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
-            decode_token_per_req=self.decode_token_per_req,
-            max_num_blocks_per_req=self.max_num_blocks_per_req,
             enable_dbo_across_dp=enable_dbo,
             is_only_prefill=is_only_prefill,
             graph_pad_size=self.graph_pad_size
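
Both hunks above drop decode_token_per_req, and the _process_reqs hunk also drops max_num_blocks_per_req while still passing block_table_tensor. As a hedged sketch (an assumption about why the explicit field can be omitted, not code from this repo), a builder that already receives the per-request block table can recover the block capacity from the tensor's shape:

import torch

# block_table_tensor shaped [num_reqs, max_num_blocks_per_req] (illustrative sizes)
block_table_tensor = torch.zeros((8, 64), dtype=torch.int32)
max_num_blocks_per_req = block_table_tensor.shape[1]
print(max_num_blocks_per_req)  # 64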

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 0 additions & 4 deletions
@@ -170,7 +170,6 @@ def propose(
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=self.runner.query_start_loc[:batch_size + 1],
             query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size + 1],
-            seq_lens=target_positions[last_token_indices] + 1,
             seq_lens_cpu=target_positions.cpu()[last_token_indices] + 1,
             num_reqs=batch_size,
             num_actual_tokens=num_tokens,

@@ -182,8 +181,6 @@
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
             attn_state=self.runner.attn_state,
-            decode_token_per_req=self.runner.decode_token_per_req,
-            max_num_blocks_per_req=self.runner.max_num_blocks_per_req,
             graph_pad_size=extra_builder_kwargs['graph_pad_size']
         )
         attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata)

@@ -302,7 +299,6 @@ def dummy_run(self,
             actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
-            decode_token_per_req=self.runner.decode_token_per_req,
         )
         attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(common_attn_metadata)
