
Commit f59ec35

[V1] Check all pooling tasks during profiling (#21299)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 2671334 commit f59ec35

File tree

2 files changed: +47 -23 lines


vllm/sequence.py

Lines changed: 7 additions & 0 deletions
@@ -1173,6 +1173,10 @@ class PoolingSequenceGroupOutput(
     # The actual type is in SequenceGroup.pooled_data
     data: Any
 
+    def get_data_nbytes(self) -> int:
+        data: torch.Tensor = self.data
+        return data.nbytes
+
     def __repr__(self) -> str:
         return f"PoolingSequenceGroupOutput(data={self.data}"
 
@@ -1234,6 +1238,9 @@ class PoolerOutput(
     """The output from a pooling operation in the pooling model."""
     outputs: list[PoolingSequenceGroupOutput]
 
+    def get_data_nbytes(self) -> int:
+        return sum(o.get_data_nbytes() for o in self.outputs)
+
     def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
        return self.outputs[idx]
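
A minimal sketch (not part of the commit) of how these two helpers combine, assuming each PoolingSequenceGroupOutput wraps a torch.Tensor as the diff above implies; the class names below are simplified stand-ins:

import torch

class SeqOutput:
    """Stand-in for PoolingSequenceGroupOutput: wraps one request's pooled tensor."""
    def __init__(self, data: torch.Tensor):
        self.data = data

    def get_data_nbytes(self) -> int:
        return self.data.nbytes

class BatchOutput:
    """Stand-in for PoolerOutput: sums the byte size over all requests."""
    def __init__(self, outputs: list[SeqOutput]):
        self.outputs = outputs

    def get_data_nbytes(self) -> int:
        return sum(o.get_data_nbytes() for o in self.outputs)

out = BatchOutput([SeqOutput(torch.zeros(1024)), SeqOutput(torch.zeros(4096))])
print(out.get_data_nbytes())  # 4 bytes per float32 * (1024 + 4096) = 20480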

vllm/v1/worker/gpu_model_runner.py

Lines changed: 40 additions & 23 deletions
@@ -41,7 +41,7 @@
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingParams, PoolingTask
 from vllm.sampling_params import SamplingType
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up)
@@ -1819,7 +1819,7 @@ def load_model(self, eep_scale_up: bool = False) -> None:
         old_global_expert_indices = None
         rank_mapping = None
 
-        with DeviceMemoryProfiler() as m:  # noqa: SIM117
+        with DeviceMemoryProfiler() as m:
             time_before_load = time.perf_counter()
             model_loader = get_model_loader(self.load_config)
             if not hasattr(self, "model"):
@@ -2215,12 +2215,11 @@ def _dummy_sampler_run(
         )
         return sampler_output
 
-    @torch.inference_mode()
-    def _dummy_pooler_run(
+    def _dummy_pooler_run_task(
         self,
         hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
-
+        task: PoolingTask,
+    ) -> PoolerOutput:
         num_tokens = hidden_states.shape[0]
         max_num_reqs = self.scheduler_config.max_num_seqs
         num_reqs = min(num_tokens, max_num_reqs)
@@ -2232,37 +2231,55 @@ def _dummy_pooler_run(
 
         hidden_states_list = list(
             torch.split(hidden_states, num_scheduled_tokens_list))
-
         req_num_tokens = num_tokens // num_reqs
 
-        model = cast(VllmModelForPooling, self.model)
-        dummy_task = self.get_supported_pooling_tasks()[0]
-        dummy_pooling_params = PoolingParams(task=dummy_task)
+        dummy_prompt_lens = torch.tensor(
+            [h.shape[0] for h in hidden_states_list],
+            device=self.device,
+        )
+        dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
+                                      dtype=torch.int32,
+                                      device=self.device)
 
-        to_update = model.pooler.get_pooling_updates(dummy_task)
+        model = cast(VllmModelForPooling, self.model)
+        dummy_pooling_params = PoolingParams(task=task)
+        to_update = model.pooler.get_pooling_updates(task)
         to_update.apply(dummy_pooling_params)
 
         dummy_metadata = PoolingMetadata(
-            prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
-                                     device=self.device),
-            prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
-                                         dtype=torch.int32,
-                                         device=self.device),
-            pooling_params=[dummy_pooling_params] * num_reqs)
+            prompt_lens=dummy_prompt_lens,
+            prompt_token_ids=dummy_token_ids,
+            pooling_params=[dummy_pooling_params] * num_reqs,
+        )
 
         try:
-            pooler_output = model.pooler(hidden_states=hidden_states_list,
-                                         pooling_metadata=dummy_metadata)
+            return model.pooler(hidden_states=hidden_states_list,
+                                pooling_metadata=dummy_metadata)
         except RuntimeError as e:
             if 'out of memory' in str(e):
                 raise RuntimeError(
-                    "CUDA out of memory occurred when warming up pooler with "
-                    f"{num_reqs} dummy requests. Please try lowering "
-                    "`max_num_seqs` or `gpu_memory_utilization` when "
+                    "CUDA out of memory occurred when warming up pooler "
+                    f"({task=}) with {num_reqs} dummy requests. Please try "
+                    "lowering `max_num_seqs` or `gpu_memory_utilization` when "
                     "initializing the engine.") from e
             else:
                 raise e
-        return pooler_output
+
+    @torch.inference_mode()
+    def _dummy_pooler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> PoolerOutput:
+        # Find the task that has the largest output for subsequent steps
+        output_size = dict[PoolingTask, float]()
+        for task in self.get_supported_pooling_tasks():
+            # Run a full batch with each task to ensure none of them OOMs
+            output = self._dummy_pooler_run_task(hidden_states, task)
+            output_size[task] = output.get_data_nbytes()
+            del output  # Allow GC
+
+        max_task = max(output_size.items(), key=lambda x: x[1])[0]
+        return self._dummy_pooler_run_task(hidden_states, max_task)
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
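
An illustrative sketch (hypothetical helper, not vLLM API) of the profiling pattern that _dummy_pooler_run now follows: exercise every supported task once so each is OOM-checked, record how many bytes its output occupies, then rerun only the most memory-hungry task so peak memory is measured against the worst case:

from typing import Callable, Sequence

def profile_worst_case_task(tasks: Sequence[str],
                            run_task: Callable[[str], object],
                            nbytes_of: Callable[[object], int]) -> object:
    """Run every task once, then rerun the one with the largest output."""
    output_size: dict[str, int] = {}
    for task in tasks:
        output = run_task(task)            # each task gets a full dummy batch
        output_size[task] = nbytes_of(output)
        del output                         # release the output before the next task

    max_task = max(output_size, key=output_size.__getitem__)
    return run_task(max_task)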
