 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingParams, PoolingTask
 from vllm.sampling_params import SamplingType
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up)
@@ -1819,7 +1819,7 @@ def load_model(self, eep_scale_up: bool = False) -> None:
         old_global_expert_indices = None
         rank_mapping = None

-        with DeviceMemoryProfiler() as m:  # noqa: SIM117
+        with DeviceMemoryProfiler() as m:
             time_before_load = time.perf_counter()
             model_loader = get_model_loader(self.load_config)
             if not hasattr(self, "model"):
@@ -2215,12 +2215,11 @@ def _dummy_sampler_run(
         )
         return sampler_output

-    @torch.inference_mode()
-    def _dummy_pooler_run(
+    def _dummy_pooler_run_task(
         self,
         hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
-
+        task: PoolingTask,
+    ) -> PoolerOutput:
         num_tokens = hidden_states.shape[0]
         max_num_reqs = self.scheduler_config.max_num_seqs
         num_reqs = min(num_tokens, max_num_reqs)
@@ -2232,37 +2231,55 @@ def _dummy_pooler_run(

         hidden_states_list = list(
             torch.split(hidden_states, num_scheduled_tokens_list))
-
         req_num_tokens = num_tokens // num_reqs

-        model = cast(VllmModelForPooling, self.model)
-        dummy_task = self.get_supported_pooling_tasks()[0]
-        dummy_pooling_params = PoolingParams(task=dummy_task)
+        dummy_prompt_lens = torch.tensor(
+            [h.shape[0] for h in hidden_states_list],
+            device=self.device,
+        )
+        dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
+                                      dtype=torch.int32,
+                                      device=self.device)

-        to_update = model.pooler.get_pooling_updates(dummy_task)
+        model = cast(VllmModelForPooling, self.model)
+        dummy_pooling_params = PoolingParams(task=task)
+        to_update = model.pooler.get_pooling_updates(task)
         to_update.apply(dummy_pooling_params)

         dummy_metadata = PoolingMetadata(
-            prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
-                                     device=self.device),
-            prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
-                                         dtype=torch.int32,
-                                         device=self.device),
-            pooling_params=[dummy_pooling_params] * num_reqs)
+            prompt_lens=dummy_prompt_lens,
+            prompt_token_ids=dummy_token_ids,
+            pooling_params=[dummy_pooling_params] * num_reqs,
+        )

         try:
-            pooler_output = model.pooler(hidden_states=hidden_states_list,
-                                         pooling_metadata=dummy_metadata)
+            return model.pooler(hidden_states=hidden_states_list,
+                                pooling_metadata=dummy_metadata)
         except RuntimeError as e:
             if 'out of memory' in str(e):
                 raise RuntimeError(
-                    "CUDA out of memory occurred when warming up pooler with "
-                    f"{num_reqs} dummy requests. Please try lowering "
-                    "`max_num_seqs` or `gpu_memory_utilization` when "
+                    "CUDA out of memory occurred when warming up pooler "
+                    f"({task=}) with {num_reqs} dummy requests. Please try "
+                    "lowering `max_num_seqs` or `gpu_memory_utilization` when "
                     "initializing the engine.") from e
             else:
                 raise e
-        return pooler_output
+
+    @torch.inference_mode()
+    def _dummy_pooler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> PoolerOutput:
+        # Find the task that has the largest output for subsequent steps
+        output_size = dict[PoolingTask, float]()
+        for task in self.get_supported_pooling_tasks():
+            # Run a full batch with each task to ensure none of them OOMs
+            output = self._dummy_pooler_run_task(hidden_states, task)
+            output_size[task] = output.get_data_nbytes()
+            del output  # Allow GC
+
+        max_task = max(output_size.items(), key=lambda x: x[1])[0]
+        return self._dummy_pooler_run_task(hidden_states, max_task)

     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
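
For illustration only, below is a standalone sketch of the probe-then-rerun pattern that the new `_dummy_pooler_run` follows: run a dummy batch once per supported pooling task, record each output's size, then re-run the task with the largest output so memory profiling captures the worst case. The task names, output shapes, and helper functions here are made up for the sketch and are not vLLM's actual classes or APIs.

```python
import torch

# Hypothetical per-task outputs for a batch of 4 dummy requests
# (stand-ins for what a pooler would return per task).
MOCK_TASK_OUTPUTS = {
    "encode": lambda: [torch.zeros(512, 4096) for _ in range(4)],  # token-level states
    "embed": lambda: [torch.zeros(4096) for _ in range(4)],        # one vector per request
    "classify": lambda: [torch.zeros(2) for _ in range(4)],        # class logits per request
}


def run_task(task: str) -> list[torch.Tensor]:
    """Stand-in for _dummy_pooler_run_task: run a full dummy batch for `task`."""
    return MOCK_TASK_OUTPUTS[task]()


def data_nbytes(output: list[torch.Tensor]) -> int:
    """Stand-in for PoolerOutput.get_data_nbytes(): total bytes held by the output."""
    return sum(t.numel() * t.element_size() for t in output)


def dummy_pooler_run() -> list[torch.Tensor]:
    # Probe every task once so none of them OOMs during warmup.
    output_size: dict[str, int] = {}
    for task in MOCK_TASK_OUTPUTS:
        output = run_task(task)
        output_size[task] = data_nbytes(output)
        del output  # let the probe result be garbage-collected

    # Re-run the task with the largest output so the memory profiler
    # observes the worst case.
    max_task = max(output_size.items(), key=lambda x: x[1])[0]
    return run_task(max_task)


if __name__ == "__main__":
    result = dummy_pooler_run()
    print(f"profiled with {len(result)} outputs of shape {tuple(result[0].shape)}")
```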