Commit 0dd3f4f

[Misc] Minor refactoring for prepare_inputs (#23116)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 498259c · commit 0dd3f4f

File tree: 1 file changed (+21, −22 lines)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 21 additions & 22 deletions
@@ -757,10 +757,19 @@ def _prepare_inputs(
         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
         self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
+        # Note: pad query_start_loc to be non-decreasing, as kernels
+        # like FlashAttention requires that
+        self.query_start_loc_np[num_reqs + 1:].fill(cu_num_tokens[-1])
+        self.query_start_loc.copy_(self.query_start_loc_cpu, non_blocking=True)
+        query_start_loc = self.query_start_loc[:num_reqs + 1]

         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
+        # Fill unused with 0 for full cuda graph mode.
+        self.seq_lens_np[num_reqs:].fill(0)
+        self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)
+        seq_lens = self.seq_lens[:num_reqs]

         # Copy the tensors to the GPU.
         self.input_ids[:total_num_scheduled_tokens].copy_(
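For context, the standalone sketch below (illustrative names, sizes, and values, not the real vLLM buffers) shows the pattern the added lines rely on: write into a pinned CPU staging tensor through its numpy view, pad the unused tail of query_start_loc with the last prefix sum so the array stays non-decreasing as FlashAttention-style kernels expect, then issue one non-blocking host-to-device copy of the whole buffer.

# Minimal sketch of the staging pattern; names and sizes are hypothetical.
import numpy as np
import torch

max_num_reqs = 8                        # hypothetical capacity of the persistent buffers
num_reqs = 3                            # requests actually scheduled this step
cu_num_tokens = np.array([4, 9, 11])    # cumulative scheduled tokens per request

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

# Persistent buffers: a pinned CPU tensor plus a numpy view for cheap writes.
query_start_loc_cpu = torch.zeros(max_num_reqs + 1, dtype=torch.int32,
                                  pin_memory=use_cuda)
query_start_loc_np = query_start_loc_cpu.numpy()
query_start_loc_gpu = torch.zeros_like(query_start_loc_cpu, device=device)

# Fill on the CPU side.
query_start_loc_np[0] = 0
query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
# Pad the unused tail with the last prefix sum so the array stays non-decreasing.
query_start_loc_np[num_reqs + 1:].fill(cu_num_tokens[-1])

# One async copy of the whole buffer instead of several sliced copies.
query_start_loc_gpu.copy_(query_start_loc_cpu, non_blocking=True)
print(query_start_loc_gpu)

Because the full buffer is copied every step, the device tensor never retains stale values from a previous, larger batch.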
@@ -776,22 +785,6 @@ def _prepare_inputs(
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)

-        self.query_start_loc[:num_reqs + 1].copy_(
-            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
-        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
-                                       non_blocking=True)
-
-        # Fill unused with 0 for full cuda graph mode.
-        self.seq_lens[num_reqs:].fill_(0)
-        # Note: pad query_start_loc to be non-decreasing, as kernels
-        # like FlashAttention requires that
-        self.query_start_loc[num_reqs + 1:].fill_(
-            self.query_start_loc_cpu[num_reqs].item())
-
-        query_start_loc = self.query_start_loc[:num_reqs + 1]
-
-        spec_decode_common_attn_metadata = None
-
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
         if not use_spec_decode:
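The block removed in this hunk is the old device-side version of the same preparation: sliced copies into the GPU tensors followed by fill_ calls on their unused tails. A toy contrast of the two copy patterns (plain CPU tensors, purely illustrative, not the vLLM buffers):

# Illustrative only: old per-slice copy + destination-side fill vs.
# new "prepare fully on the staging buffer, then one full copy".
import torch

n, num_reqs = 8, 3
staging = torch.arange(n, dtype=torch.int32)   # stands in for the CPU buffer
dest = torch.empty(n, dtype=torch.int32)       # stands in for the GPU buffer

# Old pattern: copy only the used prefix, then fill the tail on the destination.
dest[:num_reqs].copy_(staging[:num_reqs], non_blocking=True)
dest[num_reqs:].fill_(0)

# New pattern: pad the tail on the staging buffer, then copy the whole thing once.
staging[num_reqs:].fill_(0)
dest.copy_(staging, non_blocking=True)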
@@ -860,6 +853,13 @@ def _prepare_inputs(
                     per_layer_metadata[layer_name]
                 attn_metadata[layer_name] = encoder_attn_metadata

+        # Used in the below loop.
+        query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
+        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
+        num_computed_tokens_cpu = (
+            self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
+        spec_decode_common_attn_metadata = None
+
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -874,12 +874,11 @@ def _prepare_inputs(
             blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)

             common_attn_metadata = CommonAttentionMetadata(
-                query_start_loc=self.query_start_loc[:num_reqs + 1],
-                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-                seq_lens=self.seq_lens[:num_reqs],
-                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-                num_computed_tokens_cpu=self.input_batch.
-                num_computed_tokens_cpu_tensor[:num_reqs],
+                query_start_loc=query_start_loc,
+                query_start_loc_cpu=query_start_loc_cpu,
+                seq_lens=seq_lens,
+                seq_lens_cpu=seq_lens_cpu,
+                num_computed_tokens_cpu=num_computed_tokens_cpu,
                 num_reqs=num_reqs,
                 num_actual_tokens=total_num_scheduled_tokens,
                 max_query_len=max_num_scheduled_tokens,
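The last two hunks are the consuming side of the refactor: the per-request CPU slices are taken once before the KV-cache-group loop, and CommonAttentionMetadata is then built from those local names instead of re-slicing the self.* tensors for every group. A minimal sketch of that hoisting pattern, with hypothetical group names and a plain dict standing in for the metadata object:

# Hypothetical sketch of hoisting loop-invariant slices out of a per-group loop.
import torch

num_reqs = 3
seq_lens_cpu_full = torch.tensor([7, 12, 5, 0, 0], dtype=torch.int32)
query_start_loc_cpu_full = torch.tensor([0, 7, 19, 24, 24, 24], dtype=torch.int32)
kv_cache_groups = ["full_attention", "sliding_window"]   # placeholder names

# Hoisted: these views are identical for every KV cache group.
seq_lens_cpu = seq_lens_cpu_full[:num_reqs]
query_start_loc_cpu = query_start_loc_cpu_full[:num_reqs + 1]

metadata_per_group = {}
for group in kv_cache_groups:
    # Only group-specific fields (e.g. the block table) would differ here.
    metadata_per_group[group] = dict(
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens_cpu=seq_lens_cpu,
        num_reqs=num_reqs,
    )
print(metadata_per_group["full_attention"])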

0 commit comments
