@@ -783,28 +783,8 @@ def _prepare_inputs(
 
         logits_indices_padded = None
         if self.cache_config.kv_sharing_fast_prefill:
-            assert self.kv_sharing_fast_prefill_logits_indices is not None
-            num_logits = logits_indices.shape[0]
-            assert num_logits > 0
-            self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
+            logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
                 logits_indices)
-            # There may be leftover indices in logits_indices[num_logits:]
-            # from previous iterations, whose values may be greater than the
-            # batch size in the current iteration. To ensure indices are always
-            # valid, we fill the padded indices with the last index.
-            self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
-                logits_indices[-1].item())
-            if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-                    and num_logits <= self.cudagraph_batch_sizes[-1]):
-                # Use piecewise CUDA graphs.
-                # Add padding to the batch size.
-                num_logits_padded = self.vllm_config.pad_for_cudagraph(
-                    num_logits)
-            else:
-                num_logits_padded = num_logits
-            logits_indices_padded = (
-                self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
-            )
 
         attn_metadata: dict[str, Any] = {}
 
@@ -1109,6 +1089,32 @@ def _calc_spec_decode_metadata(
         )
         return metadata
 
+    def _prepare_kv_sharing_fast_prefill(
+        self,
+        logits_indices: torch.Tensor,
+    ) -> torch.Tensor:
+        assert self.kv_sharing_fast_prefill_logits_indices is not None
+        num_logits = logits_indices.shape[0]
+        assert num_logits > 0
+        self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
+            logits_indices)
+        # There may be leftover indices in logits_indices[num_logits:]
+        # from previous iterations, whose values may be greater than the
+        # batch size in the current iteration. To ensure indices are always
+        # valid, we fill the padded indices with the last index.
+        self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
+            logits_indices[-1].item())
+        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                and num_logits <= self.cudagraph_batch_sizes[-1]):
+            # Use piecewise CUDA graphs.
+            # Add padding to the batch size.
+            num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits)
+        else:
+            num_logits_padded = num_logits
+        logits_indices_padded = (
+            self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded])
+        return logits_indices_padded
+
     def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
         if not scheduled_encoder_inputs:
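
Aside: the extracted helper packages a common trick for making CUDA-graph replay safe with a persistent index buffer. Below is a minimal standalone sketch of that technique in plain PyTorch; persistent_indices, pad_to_bucket, and the constants are illustrative stand-ins for self.kv_sharing_fast_prefill_logits_indices, vllm_config.pad_for_cudagraph, and self.cudagraph_batch_sizes, not vLLM's actual API.

    import torch

    # Hypothetical stand-ins for state the runner keeps on self.
    CUDAGRAPH_BATCH_SIZES = [1, 2, 4, 8, 16]  # captured graph sizes, ascending
    MAX_NUM_LOGITS = 16
    persistent_indices = torch.zeros(MAX_NUM_LOGITS, dtype=torch.int64)

    def pad_to_bucket(n: int) -> int:
        # Assumed equivalent of pad_for_cudagraph: round n up to the
        # smallest captured batch size that can hold it.
        return next(size for size in CUDAGRAPH_BATCH_SIZES if size >= n)

    def prepare_fast_prefill_indices(logits_indices: torch.Tensor,
                                     use_cudagraph: bool = True) -> torch.Tensor:
        num_logits = logits_indices.shape[0]
        assert num_logits > 0
        # Write this iteration's live indices into the persistent buffer
        # that a captured CUDA graph would read from.
        persistent_indices[:num_logits].copy_(logits_indices)
        # Stale entries past num_logits may point beyond this iteration's
        # batch; overwrite them with the last valid index so every slot
        # stays in range.
        persistent_indices[num_logits:].fill_(logits_indices[-1].item())
        if use_cudagraph and num_logits <= CUDAGRAPH_BATCH_SIZES[-1]:
            num_padded = pad_to_bucket(num_logits)  # match a captured size
        else:
            num_padded = num_logits  # eager fallback: no padding needed
        return persistent_indices[:num_padded]

    padded = prepare_fast_prefill_indices(torch.tensor([0, 3, 5]))
    print(padded)  # tensor([0, 3, 5, 5]) -- length rounded up to bucket size 4

Filling the tail with the last live index (rather than leaving stale values) means the padded slots gather rows that actually exist in the current batch; the duplicated outputs are simply computed and discarded, which is the price of replaying a fixed-shape graph.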