
Commit 811ac13

[Core] Factor out common logic for MM budget calculation (#22228)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent e79a12f commit 811ac13
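
Note: only the diff for vllm/v1/worker/gpu_model_runner.py is shown below. The MultiModalBudget helper that this commit factors out lives in one of the other two changed files and does not appear on this page. The sketch below is inferred from how the runner constructs and calls the helper in the diff, and from the inline logic the diff removes; the method names come from the call sites, but the bodies are assumptions, not the actual implementation.

# Illustrative sketch only -- inferred from the call sites below and from the
# removed inline logic in profile_run(); not the code added by this commit.
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget


class MultiModalBudget:
    """Bundles the encoder/decoder budget math previously inlined in the runner."""

    def __init__(self, model_config, scheduler_config, mm_registry, *,
                 max_model_len: int, max_num_reqs: int) -> None:
        self.model_config = model_config
        self.mm_registry = mm_registry
        self.max_model_len = max_model_len  # stored for completeness; unused here
        self.max_num_reqs = max_num_reqs

        # Same call the runner's __init__ used to make directly (removed below).
        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
            mm_registry=mm_registry,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

    def get_encoder_budget(self) -> int:
        # The old profile_run() took the min of the two encoder limits.
        return min(self.max_num_encoder_input_tokens, self.encoder_cache_size)

    def get_modality_with_max_tokens(self) -> tuple[str, int]:
        # Pick the non-text modality with the largest per-item token count,
        # as the removed profiling code did.
        max_tokens_by_modality = (
            self.mm_registry.get_max_tokens_per_item_by_nonzero_modality(
                self.model_config))
        return max(max_tokens_by_modality.items(), key=lambda kv: kv[1])

    def get_max_items(self, modality: str,
                      max_tokens_per_item: int) -> tuple[int, int]:
        # Per-prompt limit from the registry; per-batch limit capped by the
        # smaller of the encoder and decoder budgets, with at least one item.
        max_items_per_prompt = self.mm_registry.get_mm_limits_per_prompt(
            self.model_config)[modality]
        items_by_encoder_budget = self.get_encoder_budget() // max_tokens_per_item
        items_by_decoder_budget = self.max_num_reqs * max_items_per_prompt
        max_items_per_batch = max(
            1, min(items_by_encoder_budget, items_by_decoder_budget))
        return max_items_per_prompt, max_items_per_batch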

File tree

3 files changed: +306 -223 lines changed


vllm/v1/worker/gpu_model_runner.py

Lines changed: 108 additions & 115 deletions
@@ -36,7 +36,8 @@
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
+                                    PlaceholderRange)
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
@@ -51,7 +52,6 @@
     make_kv_sharing_fast_prefill_attention_metadata,
     make_local_attention_virtual_batches,
     reorder_batch_to_split_decodes_and_prefills)
-from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         ChunkedLocalAttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
@@ -73,7 +73,7 @@
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

 from ..sample.logits_processor import LogitsProcessorManager
-from .utils import (bind_kv_cache, gather_mm_placeholders,
+from .utils import (MultiModalBudget, bind_kv_cache, gather_mm_placeholders,
                     initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)

@@ -148,14 +148,6 @@ def __init__(
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope

-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
-            mm_registry=self.mm_registry,
-        )
-        self.max_num_encoder_input_tokens = encoder_compute_budget
-        self.encoder_cache_size = encoder_cache_size
-
         # Sampler
         self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)

@@ -330,6 +322,14 @@ def __init__(
         self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
             self.max_num_tokens, dtype=torch.int32, device=self.device)

+        self.mm_budget = (MultiModalBudget(
+            self.model_config,
+            self.scheduler_config,
+            self.mm_registry,
+            max_model_len=self.max_model_len,
+            max_num_reqs=self.max_num_reqs,
+        ) if self.is_multimodal_model else None)
+
         self.reorder_batch_threshold: Optional[int] = None

     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
@@ -578,37 +578,33 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Refresh batch metadata with any pending updates.
         self.input_batch.refresh_metadata()

-    def _init_model_kwargs_for_multimodal_model(
+    def _extract_mm_kwargs(
         self,
-        scheduler_output: Optional["SchedulerOutput"] = None,
-        num_reqs: int = -1,
-    ) -> dict[str, Any]:
-
-        model_kwargs: dict[str, Any] = {}
-        if self.is_multimodal_raw_input_supported:
-            # This model requires the raw multimodal data in input.
+        scheduler_output: "SchedulerOutput",
+    ) -> BatchedTensorInputs:
+        if self.is_multimodal_raw_input_supported:  # noqa: SIM102
             if scheduler_output:
-                multi_modal_kwargs_list = []
+                multi_modal_kwargs_list = list[MultiModalKwargs]()
                 for req in scheduler_output.scheduled_new_reqs:
                     req_mm_inputs = req.mm_inputs
                     if not isinstance(req_mm_inputs, list):
                         req_mm_inputs = list(req_mm_inputs)
                     multi_modal_kwargs_list.extend(req_mm_inputs)
-                multi_modal_kwargs = MultiModalKwargs.batch(
-                    multi_modal_kwargs_list)
-            else:
-                # The only case where SchedulerOutput is None is for
-                # a dummy run let's get some dummy data.
-                dummy_data = [
-                    self.mm_registry.get_decoder_dummy_data(
-                        model_config=self.model_config,
-                        seq_len=1).multi_modal_data for i in range(num_reqs)
-                ]
-                multi_modal_kwargs = MultiModalKwargs.batch(dummy_data)

-            model_kwargs.update(multi_modal_kwargs)
+                return MultiModalKwargs.batch(multi_modal_kwargs_list)

-        return model_kwargs
+        return {}
+
+    def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
+        if self.is_multimodal_raw_input_supported:
+            mm_budget = self.mm_budget
+            assert mm_budget is not None
+
+            dummy_modality, _ = mm_budget.get_modality_with_max_tokens()
+
+            return self._get_mm_dummy_batch(dummy_modality, num_seqs)
+
+        return {}

     def _get_cumsum_and_arange(
         self,
@@ -1517,27 +1513,26 @@ def execute_model(
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
-            input_ids = self.input_ids[:num_scheduled_tokens]
-
-            model_kwargs = self._init_model_kwargs_for_multimodal_model(
-                scheduler_output=scheduler_output)
-            inputs_embeds = self.model.get_input_embeddings(
-                input_ids=input_ids,
+            inputs_embeds_scheduled = self.model.get_input_embeddings(
+                input_ids=self.input_ids[:num_scheduled_tokens],
                 multimodal_embeddings=mm_embeds or None,
             )

             # TODO(woosuk): Avoid the copy. Optimize.
-            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
-            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            self.inputs_embeds[:num_scheduled_tokens].copy_(
+                inputs_embeds_scheduled)
+
             input_ids = None
+            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
         else:
             # For text-only models, we use token ids as input.
             # While it is possible to use embeddings as input just like the
             # multimodal models, it is not desirable for performance since
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids[:num_input_tokens]
             inputs_embeds = None
-            model_kwargs = {}
+            model_mm_kwargs = {}
         if self.uses_mrope:
             positions = self.mrope_positions[:, :num_input_tokens]
         else:
@@ -1571,7 +1566,7 @@ def execute_model(
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
                 **MultiModalKwargs.as_kwargs(
-                    model_kwargs,
+                    model_mm_kwargs,
                     device=self.device,
                 ),
             )
@@ -2149,6 +2144,30 @@ def rand_input_ids() -> torch.Tensor:
             yield
             input_ids.fill_(0)

+    def _get_mm_dummy_batch(
+        self,
+        modality: str,
+        max_items_per_batch: int,
+    ) -> BatchedTensorInputs:
+        """Dummy data for profiling and precompiling multimodal models."""
+        dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
+            model_config=self.model_config,
+            seq_len=self.max_num_tokens,
+            mm_counts={modality: 1},
+        )
+        dummy_mm_data = dummy_decoder_data.multi_modal_data
+
+        # Result in the maximum GPU consumption of the model
+        dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
+        dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
+
+        batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
+                                                         max_items_per_batch)
+        return MultiModalKwargs.as_kwargs(
+            batched_dummy_mm_inputs,
+            device=self.device,
+        )
+
     @torch.inference_mode()
     def _dummy_run(
         self,
@@ -2213,16 +2232,14 @@ def _dummy_run(

         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
-            model = self.model
             if self.is_multimodal_model:
-                model_kwargs = self._init_model_kwargs_for_multimodal_model(
-                    num_reqs=num_reqs)
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
+                model_mm_kwargs = self._dummy_mm_kwargs(num_reqs)
             else:
                 input_ids = self.input_ids[:num_tokens]
                 inputs_embeds = None
-                model_kwargs = {}
+                model_mm_kwargs = {}

             if self.uses_mrope:
                 positions = self.mrope_positions[:, :num_tokens]
@@ -2247,13 +2264,13 @@
                     self.vllm_config,
                     num_tokens=num_tokens,
                     num_tokens_across_dp=num_tokens_across_dp):
-                outputs = model(
+                outputs = self.model(
                     input_ids=input_ids,
                     positions=positions,
                     intermediate_tensors=intermediate_tensors,
                     inputs_embeds=inputs_embeds,
                     **MultiModalKwargs.as_kwargs(
-                        model_kwargs,
+                        model_mm_kwargs,
                         device=self.device,
                     ),
                 )
@@ -2423,75 +2440,51 @@ def _dummy_pooler_run(

     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
-        # TODO: handle encoder-decoder models once we support them.
-        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
-                and self.encoder_cache_size > 0):
-
-            # NOTE: Currently model is profiled with a single non-text
-            # modality with the max possible input tokens even when
-            # it supports multiple.
-            max_tokens_by_modality_dict = self.mm_registry \
-                .get_max_tokens_per_item_by_nonzero_modality(self.model_config)
-            dummy_data_modality, max_tokens_per_mm_item = max(
-                max_tokens_by_modality_dict.items(), key=lambda item: item[1])
-
-            # Check how many items of this modality can be supported by
-            # the encoder budget.
-            encoder_budget = min(self.max_num_encoder_input_tokens,
-                                 self.encoder_cache_size)
-
-            max_num_mm_items_encoder_budget = encoder_budget // \
-                max_tokens_per_mm_item
-
-            # Check how many items of this modality can be supported by
-            # the decoder budget.
-            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
-                self.model_config)[dummy_data_modality]
-
-            # NOTE: We do not consider max_num_batched_tokens on purpose
-            # because the multimodal embeddings can be generated in advance
-            # and chunked prefilled.
-            max_num_mm_items_decoder_budget = self.max_num_reqs * \
-                max_mm_items_per_req
-
-            max_num_mm_items = max(
-                1,
-                min(max_num_mm_items_encoder_budget,
-                    max_num_mm_items_decoder_budget))
-
-            logger.info(
-                "Encoder cache will be initialized with a budget of %s tokens,"
-                " and profiled with %s %s items of the maximum feature size.",
-                encoder_budget, max_num_mm_items, dummy_data_modality)
-
-            # Create dummy batch of multimodal inputs.
-            dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
-                model_config=self.model_config,
-                seq_len=max_tokens_per_mm_item,
-                mm_counts={
-                    dummy_data_modality: 1
-                },
-            ).multi_modal_data
-
-            batched_dummy_mm_inputs = MultiModalKwargs.batch(
-                [dummy_mm_kwargs] * max_num_mm_items,
-                pin_memory=self.pin_memory)
-            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
-                batched_dummy_mm_inputs,
-                device=self.device,
-            )
+        if self.is_multimodal_model:
+            mm_budget = self.mm_budget
+            assert mm_budget is not None
+
+            # TODO: handle encoder-decoder models once we support them.
+            if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
+                # NOTE: Currently model is profiled with a single non-text
+                # modality with the max possible input tokens even when
+                # it supports multiple.
+                (
+                    dummy_modality,
+                    max_tokens,
+                ) = mm_budget.get_modality_with_max_tokens()
+                (
+                    max_mm_items_per_prompt,
+                    max_mm_items_per_batch,
+                ) = mm_budget.get_max_items(dummy_modality, max_tokens)
+
+                logger.info(
+                    "Encoder cache will be initialized with a budget of "
+                    "%s tokens, and profiled with %s %s items of the maximum "
+                    "feature size.",
+                    encoder_budget,
+                    max_mm_items_per_batch,
+                    dummy_modality,
+                )

-            # Run multimodal encoder.
-            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
-                **batched_dummy_mm_inputs)
+                # Create dummy batch of multimodal inputs.
+                batched_dummy_mm_inputs = self._get_mm_dummy_batch(
+                    dummy_modality,
+                    max_mm_items_per_batch,
+                )

-            sanity_check_mm_encoder_outputs(
-                dummy_encoder_outputs,
-                expected_num_items=max_num_mm_items,
-            )
+                # Run multimodal encoder.
+                dummy_encoder_outputs = self.model.get_multimodal_embeddings(
+                    **batched_dummy_mm_inputs)
+
+                sanity_check_mm_encoder_outputs(
+                    dummy_encoder_outputs,
+                    expected_num_items=max_mm_items_per_batch,
+                )

-            # Cache the dummy encoder outputs.
-            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
+                # Cache the dummy encoder outputs.
+                self.encoder_cache["tmp"] = dict(
+                    enumerate(dummy_encoder_outputs))

         # Add `is_profile` here to pre-allocate communication buffers
         hidden_states, last_hidden_states \
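
For reference, the budget arithmetic that profile_run() now delegates to mm_budget works out as in the sketch below, mirroring the inline logic the diff removes; the concrete numbers are assumed purely for illustration.

# Worked example of the profiling budget math (all numbers assumed).
max_num_encoder_input_tokens = 8192  # encoder compute budget (assumed)
encoder_cache_size = 8192            # encoder cache budget (assumed)
max_tokens_per_mm_item = 2048        # largest per-item feature size (assumed)
max_num_reqs = 256                   # decoder-side request limit (assumed)
max_mm_items_per_prompt = 1          # registry limit for that modality (assumed)

encoder_budget = min(max_num_encoder_input_tokens, encoder_cache_size)    # 8192
items_by_encoder = encoder_budget // max_tokens_per_mm_item               # 4
items_by_decoder = max_num_reqs * max_mm_items_per_prompt                 # 256
max_mm_items_per_batch = max(1, min(items_by_encoder, items_by_decoder))  # 4
# profile_run() then builds a dummy batch of 4 items of that modality.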
