from vllm.model_executor.models.interfaces_base import (
    VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
+                                    PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
    make_kv_sharing_fast_prefill_attention_metadata,
    make_local_attention_virtual_batches,
    reorder_batch_to_split_decodes_and_prefills)
-from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec,
                                        ChunkedLocalAttentionSpec,
                                        FullAttentionSpec, KVCacheConfig,
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

from ..sample.logits_processor import LogitsProcessorManager
-from .utils import (bind_kv_cache, gather_mm_placeholders,
+from .utils import (MultiModalBudget, bind_kv_cache, gather_mm_placeholders,
                    initialize_kv_cache_for_kv_sharing,
                    sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
@@ -148,14 +148,6 @@ def __init__(
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope

-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
-            mm_registry=self.mm_registry,
-        )
-        self.max_num_encoder_input_tokens = encoder_compute_budget
-        self.encoder_cache_size = encoder_cache_size
-
        # Sampler
        self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
@@ -330,6 +322,14 @@ def __init__(
        self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=self.device)

+        self.mm_budget = (MultiModalBudget(
+            self.model_config,
+            self.scheduler_config,
+            self.mm_registry,
+            max_model_len=self.max_model_len,
+            max_num_reqs=self.max_num_reqs,
+        ) if self.is_multimodal_model else None)
+
        self.reorder_batch_threshold: Optional[int] = None

    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
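
The encoder-budget bookkeeping removed from `__init__` above is folded into the `MultiModalBudget` helper imported from `.utils`. Below is a rough, illustrative sketch of the interface this diff relies on (`get_encoder_budget`, `get_modality_with_max_tokens`, `get_max_items`), reconstructed from the logic deleted in the `__init__` and `profile_run` hunks; the class name suffix, internals, and exact signatures are assumptions, not the actual implementation.

```python
# Hypothetical sketch of the MultiModalBudget interface assumed by this diff.
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget


class MultiModalBudgetSketch:

    def __init__(self, model_config, scheduler_config, mm_registry, *,
                 max_model_len: int, max_num_reqs: int):
        self.model_config = model_config
        self.mm_registry = mm_registry
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len  # unused in this sketch

        # Same quantities the runner previously computed inline in __init__.
        compute_budget, cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
            mm_registry=mm_registry,
        )
        self.max_num_encoder_input_tokens = compute_budget
        self.encoder_cache_size = cache_size

    def get_encoder_budget(self) -> int:
        return min(self.max_num_encoder_input_tokens, self.encoder_cache_size)

    def get_modality_with_max_tokens(self) -> tuple[str, int]:
        # The non-text modality whose single item needs the most tokens.
        max_tokens_by_modality = (
            self.mm_registry.get_max_tokens_per_item_by_nonzero_modality(
                self.model_config))
        return max(max_tokens_by_modality.items(), key=lambda kv: kv[1])

    def get_max_items(self, modality: str,
                      max_tokens_per_item: int) -> tuple[int, int]:
        # Items per prompt come from the per-prompt multimodal limit; items
        # per batch are capped by both the encoder and decoder budgets.
        per_prompt = self.mm_registry.get_mm_limits_per_prompt(
            self.model_config)[modality]
        encoder_limit = self.get_encoder_budget() // max_tokens_per_item
        decoder_limit = self.max_num_reqs * per_prompt
        per_batch = max(1, min(encoder_limit, decoder_limit))
        return per_prompt, per_batch
```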
@@ -578,37 +578,33 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        # Refresh batch metadata with any pending updates.
        self.input_batch.refresh_metadata()

-    def _init_model_kwargs_for_multimodal_model(
+    def _extract_mm_kwargs(
        self,
-        scheduler_output: Optional["SchedulerOutput"] = None,
-        num_reqs: int = -1,
-    ) -> dict[str, Any]:
-
-        model_kwargs: dict[str, Any] = {}
-        if self.is_multimodal_raw_input_supported:
-            # This model requires the raw multimodal data in input.
+        scheduler_output: "SchedulerOutput",
+    ) -> BatchedTensorInputs:
+        if self.is_multimodal_raw_input_supported:  # noqa: SIM102
            if scheduler_output:
-                multi_modal_kwargs_list = []
+                multi_modal_kwargs_list = list[MultiModalKwargs]()
                for req in scheduler_output.scheduled_new_reqs:
                    req_mm_inputs = req.mm_inputs
                    if not isinstance(req_mm_inputs, list):
                        req_mm_inputs = list(req_mm_inputs)
                    multi_modal_kwargs_list.extend(req_mm_inputs)
-                multi_modal_kwargs = MultiModalKwargs.batch(
-                    multi_modal_kwargs_list)
-            else:
-                # The only case where SchedulerOutput is None is for
-                # a dummy run let's get some dummy data.
-                dummy_data = [
-                    self.mm_registry.get_decoder_dummy_data(
-                        model_config=self.model_config,
-                        seq_len=1).multi_modal_data for i in range(num_reqs)
-                ]
-                multi_modal_kwargs = MultiModalKwargs.batch(dummy_data)

-            model_kwargs.update(multi_modal_kwargs)
+                return MultiModalKwargs.batch(multi_modal_kwargs_list)

-        return model_kwargs
+        return {}
+
+    def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
+        if self.is_multimodal_raw_input_supported:
+            mm_budget = self.mm_budget
+            assert mm_budget is not None
+
+            dummy_modality, _ = mm_budget.get_modality_with_max_tokens()
+
+            return self._get_mm_dummy_batch(dummy_modality, num_seqs)
+
+        return {}

    def _get_cumsum_and_arange(
        self,
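
For orientation, a minimal sketch of how the two new helpers are consumed, mirroring the `execute_model` and `_dummy_run` hunks further down. The `forward_kwargs` wrapper is hypothetical and not part of this PR; only the helper methods and `MultiModalKwargs.as_kwargs` usage come from the diff itself.

```python
from vllm.multimodal.inputs import MultiModalKwargs


def forward_kwargs(runner, scheduler_output=None, num_seqs: int = 1):
    # Real batches: raw multimodal kwargs come from the scheduler output.
    # Dummy/profiling runs: synthesize maximum-size inputs instead.
    if scheduler_output is not None:
        model_mm_kwargs = runner._extract_mm_kwargs(scheduler_output)
    else:
        model_mm_kwargs = runner._dummy_mm_kwargs(num_seqs)

    # Moved onto the target device right before the model forward call.
    return MultiModalKwargs.as_kwargs(model_mm_kwargs, device=runner.device)
```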
@@ -1517,27 +1513,26 @@ def execute_model(
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
-            input_ids = self.input_ids[:num_scheduled_tokens]
-
-            model_kwargs = self._init_model_kwargs_for_multimodal_model(
-                scheduler_output=scheduler_output)
-            inputs_embeds = self.model.get_input_embeddings(
-                input_ids=input_ids,
+            inputs_embeds_scheduled = self.model.get_input_embeddings(
+                input_ids=self.input_ids[:num_scheduled_tokens],
                multimodal_embeddings=mm_embeds or None,
            )

            # TODO(woosuk): Avoid the copy. Optimize.
-            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
-            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            self.inputs_embeds[:num_scheduled_tokens].copy_(
+                inputs_embeds_scheduled)
+
            input_ids = None
+            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
-            model_kwargs = {}
+            model_mm_kwargs = {}

        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
@@ -1571,7 +1566,7 @@ def execute_model(
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                **MultiModalKwargs.as_kwargs(
-                    model_kwargs,
+                    model_mm_kwargs,
                    device=self.device,
                ),
            )
@@ -2149,6 +2144,30 @@ def rand_input_ids() -> torch.Tensor:
            yield
            input_ids.fill_(0)

+    def _get_mm_dummy_batch(
+        self,
+        modality: str,
+        max_items_per_batch: int,
+    ) -> BatchedTensorInputs:
+        """Dummy data for profiling and precompiling multimodal models."""
+        dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
+            model_config=self.model_config,
+            seq_len=self.max_num_tokens,
+            mm_counts={modality: 1},
+        )
+        dummy_mm_data = dummy_decoder_data.multi_modal_data
+
+        # Result in the maximum GPU consumption of the model
+        dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
+        dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
+
+        batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
+                                                         max_items_per_batch)
+        return MultiModalKwargs.as_kwargs(
+            batched_dummy_mm_inputs,
+            device=self.device,
+        )
+
    @torch.inference_mode()
    def _dummy_run(
        self,
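
A toy illustration in plain `torch` (not the vLLM multimodal API; the shapes and helper name are assumptions) of why `_get_mm_dummy_batch` replicates a single maximum-size item: a batch of N copies of the largest possible item upper-bounds the encoder's activation memory for any real batch of up to N items.

```python
import torch


def build_dummy_batch(largest_item: torch.Tensor,
                      max_items_per_batch: int) -> torch.Tensor:
    # e.g. largest_item could be the pixel tensor of the biggest image the
    # model accepts; vLLM carries the real thing inside MultiModalKwargs.
    return torch.stack([largest_item] * max_items_per_batch)


dummy_item = torch.zeros(3, 336, 336)      # assumed shape, illustrative only
batch = build_dummy_batch(dummy_item, 8)   # -> shape (8, 3, 336, 336)
```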
@@ -2213,16 +2232,14 @@ def _dummy_run(

        with self.maybe_dummy_run_with_lora(self.lora_config,
                                            num_scheduled_tokens):
-            model = self.model
            if self.is_multimodal_model:
-                model_kwargs = self._init_model_kwargs_for_multimodal_model(
-                    num_reqs=num_reqs)
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_tokens]
+                model_mm_kwargs = self._dummy_mm_kwargs(num_reqs)
            else:
                input_ids = self.input_ids[:num_tokens]
                inputs_embeds = None
-                model_kwargs = {}
+                model_mm_kwargs = {}

            if self.uses_mrope:
                positions = self.mrope_positions[:, :num_tokens]
@@ -2247,13 +2264,13 @@ def _dummy_run(
                    self.vllm_config,
                    num_tokens=num_tokens,
                    num_tokens_across_dp=num_tokens_across_dp):
-                outputs = model(
+                outputs = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                    **MultiModalKwargs.as_kwargs(
-                        model_kwargs,
+                        model_mm_kwargs,
                        device=self.device,
                    ),
                )
@@ -2423,75 +2440,51 @@ def _dummy_pooler_run(

    def profile_run(self) -> None:
        # Profile with multimodal encoder & encoder cache.
-        # TODO: handle encoder-decoder models once we support them.
-        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
-                and self.encoder_cache_size > 0):
-
-            # NOTE: Currently model is profiled with a single non-text
-            # modality with the max possible input tokens even when
-            # it supports multiple.
-            max_tokens_by_modality_dict = self.mm_registry \
-                .get_max_tokens_per_item_by_nonzero_modality(self.model_config)
-            dummy_data_modality, max_tokens_per_mm_item = max(
-                max_tokens_by_modality_dict.items(), key=lambda item: item[1])
-
-            # Check how many items of this modality can be supported by
-            # the encoder budget.
-            encoder_budget = min(self.max_num_encoder_input_tokens,
-                                 self.encoder_cache_size)
-
-            max_num_mm_items_encoder_budget = encoder_budget // \
-                max_tokens_per_mm_item
-
-            # Check how many items of this modality can be supported by
-            # the decoder budget.
-            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
-                self.model_config)[dummy_data_modality]
-
-            # NOTE: We do not consider max_num_batched_tokens on purpose
-            # because the multimodal embeddings can be generated in advance
-            # and chunked prefilled.
-            max_num_mm_items_decoder_budget = self.max_num_reqs * \
-                max_mm_items_per_req
-
-            max_num_mm_items = max(
-                1,
-                min(max_num_mm_items_encoder_budget,
-                    max_num_mm_items_decoder_budget))
-
-            logger.info(
-                "Encoder cache will be initialized with a budget of %s tokens,"
-                " and profiled with %s %s items of the maximum feature size.",
-                encoder_budget, max_num_mm_items, dummy_data_modality)
-
-            # Create dummy batch of multimodal inputs.
-            dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
-                model_config=self.model_config,
-                seq_len=max_tokens_per_mm_item,
-                mm_counts={
-                    dummy_data_modality: 1
-                },
-            ).multi_modal_data
-
-            batched_dummy_mm_inputs = MultiModalKwargs.batch(
-                [dummy_mm_kwargs] * max_num_mm_items,
-                pin_memory=self.pin_memory)
-            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
-                batched_dummy_mm_inputs,
-                device=self.device,
-            )
+        if self.is_multimodal_model:
+            mm_budget = self.mm_budget
+            assert mm_budget is not None
+
+            # TODO: handle encoder-decoder models once we support them.
+            if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
+                # NOTE: Currently model is profiled with a single non-text
+                # modality with the max possible input tokens even when
+                # it supports multiple.
+                (
+                    dummy_modality,
+                    max_tokens,
+                ) = mm_budget.get_modality_with_max_tokens()
+                (
+                    max_mm_items_per_prompt,
+                    max_mm_items_per_batch,
+                ) = mm_budget.get_max_items(dummy_modality, max_tokens)
+
+                logger.info(
+                    "Encoder cache will be initialized with a budget of "
+                    "%s tokens, and profiled with %s %s items of the maximum "
+                    "feature size.",
+                    encoder_budget,
+                    max_mm_items_per_batch,
+                    dummy_modality,
+                )

-            # Run multimodal encoder.
-            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
-                **batched_dummy_mm_inputs)
+                # Create dummy batch of multimodal inputs.
+                batched_dummy_mm_inputs = self._get_mm_dummy_batch(
+                    dummy_modality,
+                    max_mm_items_per_batch,
+                )

-            sanity_check_mm_encoder_outputs(
-                dummy_encoder_outputs,
-                expected_num_items=max_num_mm_items,
-            )
+                # Run multimodal encoder.
+                dummy_encoder_outputs = self.model.get_multimodal_embeddings(
+                    **batched_dummy_mm_inputs)
+
+                sanity_check_mm_encoder_outputs(
+                    dummy_encoder_outputs,
+                    expected_num_items=max_mm_items_per_batch,
+                )

-            # Cache the dummy encoder outputs.
-            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
+                # Cache the dummy encoder outputs.
+                self.encoder_cache["tmp"] = dict(
+                    enumerate(dummy_encoder_outputs))

        # Add `is_profile` here to pre-allocate communication buffers
        hidden_states, last_hidden_states \
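
A worked example, with assumed numbers, of the item-count arithmetic the removed inline code performed and that `mm_budget.get_max_items()` is expected to encapsulate:

```python
# All values below are illustrative assumptions.
encoder_budget = 8192        # min(compute budget, encoder cache size)
max_tokens_per_item = 2048   # largest single item of the chosen modality
max_num_reqs = 256
items_per_prompt = 4         # per-prompt limit for that modality

encoder_limit = encoder_budget // max_tokens_per_item    # 4
decoder_limit = max_num_reqs * items_per_prompt          # 1024
max_mm_items_per_batch = max(1, min(encoder_limit, decoder_limit))  # -> 4
```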