@@ -51,8 +51,8 @@
 from vllm.model_executor.models.interfaces import supports_transcription
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
-from vllm.multimodal.utils import group_mm_inputs_by_modality
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.utils import group_mm_kwargs_by_modality
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
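Note on the import swap (a reading of this diff, not an authoritative description of vLLM internals): the runner stops carrying fully batched MultiModalKwargs objects and instead keeps flat per-item MultiModalKwargsItem entries; batching, device transfer, and host-memory pinning move into group_mm_kwargs_by_modality. A minimal sketch of the call-shape difference, using only names visible in this diff:

    # Old flow: group by hand, then batch and move to device explicitly.
    for grouped_mm_inputs in group_mm_inputs_by_modality(mm_inputs):
        batched = MultiModalKwargs.batch(grouped_mm_inputs)
        batched = MultiModalKwargs.as_kwargs(batched, device=device)
        ...

    # New flow: one helper batches each modality group, reports its size,
    # and handles device placement and pinned host memory.
    for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
            mm_kwargs, device=device, pin_memory=True,
    ):
        ...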
@@ -479,7 +479,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
-                mm_inputs=new_req_data.mm_inputs,
+                mm_kwargs=new_req_data.mm_kwargs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
                 pooling_params=new_req_data.pooling_params,
@@ -497,18 +497,19 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 second_per_grid_ts = []
                 audio_feature_lengths = []
                 use_audio_in_video = False
-                for mm_input in self.requests[req_id].mm_inputs:
+                for item in self.requests[req_id].mm_kwargs:
+                    mm_input = item.require_data()
                     if mm_input.get("image_grid_thw") is not None:
-                        image_grid_thw.extend(
+                        image_grid_thw.append(
                             mm_input["image_grid_thw"].tolist())
                     if mm_input.get("video_grid_thw") is not None:
-                        video_grid_thw.extend(
+                        video_grid_thw.append(
                             mm_input["video_grid_thw"].tolist())
                     if mm_input.get("second_per_grid_ts") is not None:
-                        second_per_grid_ts.extend(
+                        second_per_grid_ts.append(
                             mm_input["second_per_grid_ts"])
                     if mm_input.get("audio_feature_lengths") is not None:
-                        audio_feature_lengths.extend(
+                        audio_feature_lengths.append(
                             mm_input["audio_feature_lengths"])
                     if mm_input.get("use_audio_in_video") is True:
                         use_audio_in_video = True
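Hedged aside on the extend() -> append() switch: the loop now visits one MultiModalKwargsItem at a time (via require_data()), so each field contributes exactly one entry per iteration, whereas the old batched mm_inputs handed back several rows at once that had to be spliced in. A toy illustration of the list semantics only, not vLLM code:

    rows = []
    rows.extend([[1, 2], [3, 4]])  # batched input: splice in several rows
    rows.append([5, 6])            # per-item input: add exactly one entry
    assert rows == [[1, 2], [3, 4], [5, 6]]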
@@ -912,13 +913,13 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
             return

         # Batch the multi-modal inputs.
-        mm_inputs = list[MultiModalKwargs]()
+        mm_kwargs = list[MultiModalKwargsItem]()
         req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
             req_state = self.requests[req_id]

             for mm_input_id in encoder_input_ids:
-                mm_inputs.append(req_state.mm_inputs[mm_input_id])
+                mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
                 req_ids_pos.append(
                     (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))

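Note (illustrative, not from the patch): mm_kwargs and req_ids_pos are built in lockstep, so index i of the flattened item list lines up with the (req_id, input_id, position) record used to route encoder output i back to its request. A self-contained toy of that parallel-list bookkeeping, with hypothetical data:

    scheduled = {"req-0": [0, 1], "req-1": [0]}           # toy schedule
    stored = {"req-0": ["imgA", "imgB"], "req-1": ["imgC"]}
    flat_items, routing = [], []
    for req_id, input_ids in scheduled.items():
        for input_id in input_ids:
            flat_items.append(stored[req_id][input_id])
            routing.append((req_id, input_id))
    assert len(flat_items) == len(routing)  # index-aligned by construction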
@@ -929,14 +930,13 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         # in the same batch while still being able to benefit from batching
         # multimodal inputs. The proper solution should be reordering the
         # encoder outputs.
-        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

         encoder_outputs = []
-        for grouped_mm_inputs in grouped_mm_inputs_list:
-            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
-            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
-                                                           device=self.device)
-
+        for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
+                mm_kwargs,
+                device=self.device,
+                pin_memory=True,
+        ):
             # Run the encoder.
             # `curr_group_outputs` is either of the following:
             # 1. A tensor of shape (num_items, feature_size, hidden_size)
@@ -945,11 +945,11 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
             # (feature_size, hidden_size) in case the feature size is dynamic
             # depending on the input multimodal items.
             curr_group_outputs = self.model.get_multimodal_embeddings(
-                **batched_mm_inputs)
+                **mm_kwargs_group)

             sanity_check_mm_encoder_outputs(
                 curr_group_outputs,
-                expected_num_items=len(grouped_mm_inputs),
+                expected_num_items=num_items,
             )

             for output in curr_group_outputs:
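Hedged sketch of the grouping contract the loop above leans on: group_mm_kwargs_by_modality is assumed here to batch runs of same-modality items while preserving overall order (the surrounding comments note that properly reordering encoder outputs is still future work), yielding (modality, num_items, kwargs_group) so num_items can feed the sanity check directly. A stand-in built on itertools.groupby, not the vLLM helper itself:

    from itertools import groupby

    def group_by_modality(items):  # hypothetical stand-in
        for modality, run in groupby(items, key=lambda it: it["modality"]):
            run = list(run)
            yield modality, len(run), run

    items = [{"modality": "image"}, {"modality": "image"},
             {"modality": "audio"}]
    for modality, num_items, group in group_by_modality(items):
        print(modality, num_items)  # -> image 2, then audio 1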