diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index d289bb4578..4bf4745d30 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -207,6 +207,28 @@ def build(
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                        num_reqs + 1]
+
+        if common_attn_metadata.graph_pad_size > num_actual_tokens:
+            padded_num_tokens = common_attn_metadata.graph_pad_size - num_actual_tokens
+            seq_lens = torch.cat([
+                seq_lens,
+                torch.ones(padded_num_tokens,
+                           dtype=seq_lens.dtype,
+                           device=seq_lens.device)
+            ])
+            block_table_padding = torch.zeros(
+                (padded_num_tokens, ) + block_table.shape[1:],
+                dtype=block_table.dtype,
+                device=block_table.device)
+            block_table = torch.cat([block_table, block_table_padding], dim=0)
+            query_start_loc_cpu = torch.cat([
+                query_start_loc_cpu,
+                torch.arange(query_start_loc_cpu[-1] + 1,
+                             query_start_loc_cpu[-1] + padded_num_tokens,
+                             dtype=query_start_loc_cpu.dtype,
+                             device=query_start_loc_cpu.device)
+            ])
+
         query_start_loc = query_start_loc_cpu.to(self.device,
                                                  non_blocking=True)
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 373a73e297..36a02076bc 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1619,7 +1619,10 @@ def _build_attn_state(self, num_reqs, num_scheduled_tokens,
         return attn_state

     def _update_graph_pad_size(self, with_prefill, graph_pad_size):
-        self.graph_pad_size = -1
+        if not with_prefill:
+            self.graph_pad_size = graph_pad_size
+        else:
+            self.graph_pad_size = -1

     def _update_input_ids_and_positions(self, input_ids, positions,
                                         num_input_tokens, with_prefill,
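
For context, a minimal standalone sketch of the padding step the attention_v1.py hunk performs. When the batch is decode-only and the captured graph expects graph_pad_size tokens, the extra slots are filled with dummy single-token requests: each one adds a 1 to seq_lens, an all-zero row to block_table, and one more entry to query_start_loc. The helper name pad_decode_metadata and the toy shapes are illustrative only, not part of the patch, and this sketch appends exactly one query_start_loc entry per dummy token (arange upper bound of last + padded_num_tokens + 1), whereas the patch's upper bound is last + padded_num_tokens.

import torch


def pad_decode_metadata(seq_lens: torch.Tensor,
                        block_table: torch.Tensor,
                        query_start_loc_cpu: torch.Tensor,
                        graph_pad_size: int,
                        num_actual_tokens: int):
    # Illustrative helper, not part of the patch: pad decode-only attention
    # metadata up to the fixed token count the captured graph expects.
    if graph_pad_size <= num_actual_tokens:
        return seq_lens, block_table, query_start_loc_cpu

    padded_num_tokens = graph_pad_size - num_actual_tokens
    # Each dummy slot behaves like a one-token request.
    seq_lens = torch.cat([
        seq_lens,
        torch.ones(padded_num_tokens,
                   dtype=seq_lens.dtype,
                   device=seq_lens.device)
    ])
    # Dummy requests point at no real KV-cache blocks, so pad with zero rows.
    block_table = torch.cat([
        block_table,
        torch.zeros((padded_num_tokens, ) + block_table.shape[1:],
                    dtype=block_table.dtype,
                    device=block_table.device)
    ])
    # One extra cumulative offset per dummy token; the last entry ends up
    # equal to graph_pad_size.
    last = int(query_start_loc_cpu[-1])
    query_start_loc_cpu = torch.cat([
        query_start_loc_cpu,
        torch.arange(last + 1,
                     last + padded_num_tokens + 1,
                     dtype=query_start_loc_cpu.dtype)
    ])
    return seq_lens, block_table, query_start_loc_cpu


# Toy example: 2 real decode requests padded up to a 4-token graph.
seq_lens = torch.tensor([5, 9], dtype=torch.int32)
block_table = torch.tensor([[3, 0], [7, 8]], dtype=torch.int32)
query_start_loc_cpu = torch.tensor([0, 1, 2], dtype=torch.int32)
seq_lens, block_table, query_start_loc_cpu = pad_decode_metadata(
    seq_lens, block_table, query_start_loc_cpu,
    graph_pad_size=4, num_actual_tokens=2)
print(seq_lens)              # tensor([5, 9, 1, 1], dtype=torch.int32)
print(block_table)           # original rows plus two all-zero rows
print(query_start_loc_cpu)   # tensor([0, 1, 2, 3, 4], dtype=torch.int32)

The model_runner_v1.py change is what makes this path reachable: graph_pad_size is now forwarded to the attention metadata builder only for decode-only (no-prefill) batches and stays -1 otherwise, so the padding branch above is skipped whenever a prefill request is scheduled.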