
Commit b557136

[Misc] Move fast prefill logic to separate method (#24013)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent: acc1a6e

File tree: 1 file changed, +27 −21 lines


vllm/v1/worker/gpu_model_runner.py

Lines changed: 27 additions & 21 deletions
@@ -783,28 +783,8 @@ def _prepare_inputs(
 
         logits_indices_padded = None
         if self.cache_config.kv_sharing_fast_prefill:
-            assert self.kv_sharing_fast_prefill_logits_indices is not None
-            num_logits = logits_indices.shape[0]
-            assert num_logits > 0
-            self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
+            logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
                 logits_indices)
-            # There might be leftover indices in logits_indices[num_logits:]
-            # from previous iterations, whose values may be greater than the
-            # batch size in the current iteration. To ensure indices are always
-            # valid, we fill the padded indices with the last index.
-            self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
-                logits_indices[-1].item())
-            if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-                    and num_logits <= self.cudagraph_batch_sizes[-1]):
-                # Use piecewise CUDA graphs.
-                # Add padding to the batch size.
-                num_logits_padded = self.vllm_config.pad_for_cudagraph(
-                    num_logits)
-            else:
-                num_logits_padded = num_logits
-            logits_indices_padded = (
-                self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
-            )
 
         attn_metadata: dict[str, Any] = {}
 
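For context on the hunk above: the helper it now calls keeps a persistent index buffer across steps, and the tail-filling trick matters because stale entries can point past the current batch. Below is a minimal standalone sketch of the same idea, using assumed names (MAX_NUM_LOGITS, buf, fill_logits_indices) that are not part of vLLM:

import torch

# Assumed toy capacity for this sketch; not a vLLM constant.
MAX_NUM_LOGITS = 8
buf = torch.zeros(MAX_NUM_LOGITS, dtype=torch.int64)

def fill_logits_indices(logits_indices: torch.Tensor) -> torch.Tensor:
    # Copy the valid indices into the reused buffer.
    num_logits = logits_indices.shape[0]
    buf[:num_logits].copy_(logits_indices)
    # Leftover tail entries from an earlier, larger batch may exceed the
    # current batch size; overwriting them with the last valid index keeps
    # every padded slot in range.
    buf[num_logits:].fill_(logits_indices[-1].item())
    return buf

hidden = torch.randn(4, 16)                       # 4 valid rows this step
padded = fill_logits_indices(torch.tensor([0, 1, 3]))
out = hidden.index_select(0, padded)              # safe: every index < 4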
@@ -1109,6 +1089,32 @@ def _calc_spec_decode_metadata(
         )
         return metadata
 
+    def _prepare_kv_sharing_fast_prefill(
+        self,
+        logits_indices: torch.Tensor,
+    ) -> torch.Tensor:
+        assert self.kv_sharing_fast_prefill_logits_indices is not None
+        num_logits = logits_indices.shape[0]
+        assert num_logits > 0
+        self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
+            logits_indices)
+        # There might be leftover indices in logits_indices[num_logits:]
+        # from previous iterations, whose values may be greater than the
+        # batch size in the current iteration. To ensure indices are always
+        # valid, we fill the padded indices with the last index.
+        self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
+            logits_indices[-1].item())
+        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                and num_logits <= self.cudagraph_batch_sizes[-1]):
+            # Use piecewise CUDA graphs.
+            # Add padding to the batch size.
+            num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits)
+        else:
+            num_logits_padded = num_logits
+        logits_indices_padded = (
+            self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded])
+        return logits_indices_padded
+
     def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
         if not scheduled_encoder_inputs:
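On the padding branch in the new method: when CUDA graphs are enabled and the logits count does not exceed the largest captured batch size, the count is rounded up to a capture bucket so a pre-captured piecewise graph can be replayed. A rough stand-in for what vllm_config.pad_for_cudagraph is used for here, with assumed bucket sizes (an illustration, not vLLM's implementation):

import bisect

# Assumed capture sizes; the real list comes from the compilation config.
CUDAGRAPH_BATCH_SIZES = [1, 2, 4, 8, 16, 32]

def pad_for_cudagraph(num_logits: int, use_cudagraph: bool = True) -> int:
    # Round up to the nearest captured batch size, else leave unpadded.
    if use_cudagraph and num_logits <= CUDAGRAPH_BATCH_SIZES[-1]:
        i = bisect.bisect_left(CUDAGRAPH_BATCH_SIZES, num_logits)
        return CUDAGRAPH_BATCH_SIZES[i]
    # Too large for any captured graph (or graphs disabled): run eagerly.
    return num_logits

assert pad_for_cudagraph(3) == 4     # padded up to the next bucket
assert pad_for_cudagraph(8) == 8     # exact bucket: no extra padding
assert pad_for_cudagraph(33) == 33   # exceeds largest capture: unpadded

Padding the index buffer to a fixed bucket keeps tensor shapes static across replays, which is the shape-stability requirement that CUDA graph capture imposes.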
