 from vllm.model_executor.models.interfaces import supports_transcription
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
-from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
@@ -93,9 +91,14 @@
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

 if not vllm_version_is("0.10.0"):
+    from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+    from vllm.multimodal.utils import group_mm_kwargs_by_modality
     from vllm.tasks import GenerationTask, SupportedTask
     from vllm.v1.worker.kv_connector_model_runner_mixin import \
         KVConnectorOutput
+else:
+    from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+    from vllm.multimodal.utils import group_mm_inputs_by_modality

 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
@@ -475,20 +478,34 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 model = cast(VllmModelForPooling, self.model)
                 to_update = model.pooler.get_pooling_updates(task)
                 to_update.apply(pooling_params)
-
-            self.requests[req_id] = CachedRequestState(
-                req_id=req_id,
-                prompt_token_ids=new_req_data.prompt_token_ids,
-                mm_inputs=new_req_data.mm_inputs,
-                mm_positions=new_req_data.mm_positions,
-                sampling_params=sampling_params,
-                pooling_params=new_req_data.pooling_params,
-                generator=generator,
-                block_ids=new_req_data.block_ids,
-                num_computed_tokens=new_req_data.num_computed_tokens,
-                output_token_ids=[],
-                lora_request=new_req_data.lora_request,
-            )
+            if vllm_version_is("0.10.0"):
+                self.requests[req_id] = CachedRequestState(
+                    req_id=req_id,
+                    prompt_token_ids=new_req_data.prompt_token_ids,
+                    mm_kwargs=new_req_data.mm_inputs,
+                    mm_positions=new_req_data.mm_positions,
+                    sampling_params=sampling_params,
+                    pooling_params=new_req_data.pooling_params,
+                    generator=generator,
+                    block_ids=new_req_data.block_ids,
+                    num_computed_tokens=new_req_data.num_computed_tokens,
+                    output_token_ids=[],
+                    lora_request=new_req_data.lora_request,
+                )
+            else:
+                self.requests[req_id] = CachedRequestState(
+                    req_id=req_id,
+                    prompt_token_ids=new_req_data.prompt_token_ids,
+                    mm_kwargs=new_req_data.mm_kwargs,
+                    mm_positions=new_req_data.mm_positions,
+                    sampling_params=sampling_params,
+                    pooling_params=new_req_data.pooling_params,
+                    generator=generator,
+                    block_ids=new_req_data.block_ids,
+                    num_computed_tokens=new_req_data.num_computed_tokens,
+                    output_token_ids=[],
+                    lora_request=new_req_data.lora_request,
+                )

             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             if self.uses_mrope:
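Note: the two CachedRequestState(...) branches above differ only in which field of new_req_data supplies the multimodal kwargs. A minimal sketch of how that selection could be factored out; the helper name `_select_mm_kwargs` and the `vllm_ascend.utils` import path are assumptions for illustration, not part of this change:

```python
from vllm_ascend.utils import vllm_version_is  # assumed import path


def _select_mm_kwargs(new_req_data):
    # vLLM 0.10.0 still exposes the batched field `mm_inputs`; newer vLLM
    # exposes per-item `mm_kwargs`. CachedRequestState stores whichever one
    # exists under its own `mm_kwargs` attribute.
    return (new_req_data.mm_inputs
            if vllm_version_is("0.10.0") else new_req_data.mm_kwargs)
```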
@@ -497,21 +514,39 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 second_per_grid_ts = []
                 audio_feature_lengths = []
                 use_audio_in_video = False
-                for mm_input in self.requests[req_id].mm_inputs:
-                    if mm_input.get("image_grid_thw") is not None:
-                        image_grid_thw.extend(
-                            mm_input["image_grid_thw"].tolist())
-                    if mm_input.get("video_grid_thw") is not None:
-                        video_grid_thw.extend(
-                            mm_input["video_grid_thw"].tolist())
-                    if mm_input.get("second_per_grid_ts") is not None:
-                        second_per_grid_ts.extend(
-                            mm_input["second_per_grid_ts"])
-                    if mm_input.get("audio_feature_lengths") is not None:
-                        audio_feature_lengths.extend(
-                            mm_input["audio_feature_lengths"])
-                    if mm_input.get("use_audio_in_video") is True:
-                        use_audio_in_video = True
+                if vllm_version_is("0.10.0"):
+                    for mm_input in self.requests[req_id].mm_kwargs:
+                        if mm_input.get("image_grid_thw") is not None:
+                            image_grid_thw.extend(
+                                mm_input["image_grid_thw"].tolist())
+                        if mm_input.get("video_grid_thw") is not None:
+                            video_grid_thw.extend(
+                                mm_input["video_grid_thw"].tolist())
+                        if mm_input.get("second_per_grid_ts") is not None:
+                            second_per_grid_ts.extend(
+                                mm_input["second_per_grid_ts"])
+                        if mm_input.get("audio_feature_lengths") is not None:
+                            audio_feature_lengths.extend(
+                                mm_input["audio_feature_lengths"])
+                        if mm_input.get("use_audio_in_video") is True:
+                            use_audio_in_video = True
+                else:
+                    for item in self.requests[req_id].mm_kwargs:
+                        mm_input = item.require_data()
+                        if mm_input.get("image_grid_thw") is not None:
+                            image_grid_thw.append(
+                                mm_input["image_grid_thw"].tolist())
+                        if mm_input.get("video_grid_thw") is not None:
+                            video_grid_thw.append(
+                                mm_input["video_grid_thw"].tolist())
+                        if mm_input.get("second_per_grid_ts") is not None:
+                            second_per_grid_ts.append(
+                                mm_input["second_per_grid_ts"])
+                        if mm_input.get("audio_feature_lengths") is not None:
+                            audio_feature_lengths.append(
+                                mm_input["audio_feature_lengths"])
+                        if mm_input.get("use_audio_in_video") is True:
+                            use_audio_in_video = True

                 hf_config = self.model_config.hf_config

@@ -912,13 +947,16 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
             return

         # Batch the multi-modal inputs.
-        mm_inputs = list[MultiModalKwargs]()
+        if vllm_version_is("0.10.0"):
+            mm_kwargs = list[MultiModalKwargs]()
+        else:
+            mm_kwargs = list[MultiModalKwargsItem]()
         req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
             req_state = self.requests[req_id]

             for mm_input_id in encoder_input_ids:
-                mm_inputs.append(req_state.mm_inputs[mm_input_id])
+                mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
                 req_ids_pos.append(
                     (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))

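Aside on the idiom above: `list[MultiModalKwargs]()` and `list[MultiModalKwargsItem]()` both build an ordinary empty list; the subscript is a `types.GenericAlias` kept purely for readers and type checkers, so the two branches differ only in the documented element type. A quick standalone check:

```python
# Subscripting a builtin creates a GenericAlias, but calling it still
# constructs a plain list at runtime.
xs = list[int]()
assert xs == [] and type(xs) is list
```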
@@ -929,31 +967,54 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         # in the same batch while still being able to benefit from batching
         # multimodal inputs. The proper solution should be reordering the
         # encoder outputs.
-        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
-
         encoder_outputs = []
-        for grouped_mm_inputs in grouped_mm_inputs_list:
-            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
-            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
-                                                           device=self.device)
-
-            # Run the encoder.
-            # `curr_group_outputs` is either of the following:
-            # 1. A tensor of shape (num_items, feature_size, hidden_size)
-            # in case feature_size is fixed across all multimodal items.
-            # 2. A list or tuple (length: num_items) of tensors, each of shape
-            # (feature_size, hidden_size) in case the feature size is dynamic
-            # depending on the input multimodal items.
-            curr_group_outputs = self.model.get_multimodal_embeddings(
-                **batched_mm_inputs)
-
-            sanity_check_mm_encoder_outputs(
-                curr_group_outputs,
-                expected_num_items=len(grouped_mm_inputs),
-            )
+        if vllm_version_is("0.10.0"):
+            grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_kwargs)
+
+            for grouped_mm_inputs in grouped_mm_inputs_list:
+                batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
+                batched_mm_inputs = MultiModalKwargs.as_kwargs(
+                    batched_mm_inputs, device=self.device)
+                # Run the encoder.
+                # `curr_group_outputs` is either of the following:
+                # 1. A tensor of shape (num_items, feature_size, hidden_size)
+                # in case feature_size is fixed across all multimodal items.
+                # 2. A list or tuple (length: num_items) of tensors, each of shape
+                # (feature_size, hidden_size) in case the feature size is dynamic
+                # depending on the input multimodal items.
+                curr_group_outputs = self.model.get_multimodal_embeddings(
+                    **batched_mm_inputs)
+
+                sanity_check_mm_encoder_outputs(
+                    curr_group_outputs,
+                    expected_num_items=len(grouped_mm_inputs),
+                )
+
+                for output in curr_group_outputs:
+                    encoder_outputs.append(output)
+        else:
+            for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
+                    mm_kwargs,
+                    device=self.device,
+                    pin_memory=True,
+            ):
+                # Run the encoder.
+                # `curr_group_outputs` is either of the following:
+                # 1. A tensor of shape (num_items, feature_size, hidden_size)
+                # in case feature_size is fixed across all multimodal items.
+                # 2. A list or tuple (length: num_items) of tensors, each of shape
+                # (feature_size, hidden_size) in case the feature size is dynamic
+                # depending on the input multimodal items.
+                curr_group_outputs = self.model.get_multimodal_embeddings(
+                    **mm_kwargs_group)
+
+                sanity_check_mm_encoder_outputs(
+                    curr_group_outputs,
+                    expected_num_items=num_items,
+                )

-            for output in curr_group_outputs:
-                encoder_outputs.append(output)
+                for output in curr_group_outputs:
+                    encoder_outputs.append(output)

         # Cache the encoder outputs.
         for (req_id, input_id, pos_info), output in zip(
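Taken together, the two branches above show the caller-side difference between the grouping helpers: under vLLM 0.10.0 the runner groups the kwargs, batches each group with MultiModalKwargs.batch, and moves it to the device itself, while the newer group_mm_kwargs_by_modality batches per modality, handles device placement (with optional pinned memory), and reports the item count for each group. A condensed skeleton of the two call patterns, distilled from the diff rather than a standalone runnable example (`model`, `mm_kwargs`, and `device` stand in for the runner's attributes):

```python
# vLLM 0.10.0: caller batches and transfers each modality group itself.
for group in group_mm_inputs_by_modality(mm_kwargs):
    batched = MultiModalKwargs.as_kwargs(MultiModalKwargs.batch(group),
                                         device=device)
    outputs = model.get_multimodal_embeddings(**batched)

# Newer vLLM: the helper yields (modality, num_items, batched_kwargs) with
# batching and device placement already done.
for _, num_items, group in group_mm_kwargs_by_modality(mm_kwargs,
                                                       device=device,
                                                       pin_memory=True):
    outputs = model.get_multimodal_embeddings(**group)
```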