@@ -5,7 +5,7 @@
 
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
+                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
@@ -29,6 +29,7 @@ def __init__(
         device_config: DeviceConfig,
         load_config: LoadConfig,
         lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig],
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
         *args,
@@ -38,6 +39,7 @@ def __init__(
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.lora_config = lora_config
+        self.vision_language_config = vision_language_config
         self.load_config = load_config
         self.is_driver_worker = is_driver_worker
 
@@ -59,13 +61,14 @@ def __init__(
         self.block_size: int  # Set after initial profiling.
 
     def load_model(self) -> None:
-        self.model = get_model(model_config=self.model_config,
-                               load_config=self.load_config,
-                               device_config=self.device_config,
-                               vision_language_config=None,
-                               lora_config=self.lora_config,
-                               parallel_config=self.parallel_config,
-                               scheduler_config=self.scheduler_config)
+        self.model = get_model(
+            model_config=self.model_config,
+            load_config=self.load_config,
+            device_config=self.device_config,
+            vision_language_config=self.vision_language_config,
+            lora_config=self.lora_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config)
 
     def _prepare_prompt(
         self,
@@ -76,6 +79,7 @@ def _prepare_prompt(
         input_positions: List[int] = []
         slot_mapping: List[int] = []
         prompt_lens: List[int] = []
+        multi_modal_input_list: List[torch.Tensor] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -96,6 +100,10 @@ def _prepare_prompt(
             # is always the first token in the sequence.
             input_positions.extend(list(range(computed_len, prompt_len)))
 
+            if seq_group_metadata.multi_modal_data:
+                multi_modal_input_list.append(
+                    seq_group_metadata.multi_modal_data.data)
+
             # Compute the slot mapping.
             block_table = seq_group_metadata.block_tables[seq_id]
             # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
@@ -118,6 +126,15 @@ def _prepare_prompt(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)
 
+        if multi_modal_input_list:
+            assert self.vision_language_config, (
+                "Multi-modal inputs are only supported by "
+                "vision language models.")
+            multi_modal_input = torch.cat(multi_modal_input_list,
+                                          dim=0).to(self.device)
+        else:
+            multi_modal_input = None
+
         num_prompt_tokens = len(input_tokens)
 
         input_tokens = torch.tensor(input_tokens,
@@ -144,12 +161,8 @@ def _prepare_prompt(
             slot_mapping=slot_mapping,
             kv_cache_dtype=self.kv_cache_dtype,
         )
-        return (
-            input_tokens,
-            input_positions,
-            attn_metadata,
-            prompt_lens,
-        )
+        return (input_tokens, input_positions, attn_metadata, prompt_lens,
+                multi_modal_input)
 
     def _prepare_decode(
         self,
@@ -336,14 +349,16 @@ def prepare_input_tensors(
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
               SamplingMetadata]:
+        multi_modal_input = None
         if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
-                (input_tokens, input_positions, attn_metadata,
-                 prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
+                (input_tokens, input_positions, attn_metadata, prompt_lens,
+                 multi_modal_input
+                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions,
                 attn_metadata) = self._prepare_decode(seq_group_metadata_list)
@@ -376,20 +391,17 @@ def prepare_input_tensors(
                 perform_sampling=False,
             )
 
-        return (
-            input_tokens,
-            input_positions,
-            attn_metadata,
-            sampling_metadata,
-        )
+        return (input_tokens, input_positions, attn_metadata,
+                sampling_metadata, multi_modal_input)
 
     @torch.inference_mode()
     def execute_model(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
-        (input_tokens, input_positions, attn_metadata, sampling_metadata
+        (input_tokens, input_positions, attn_metadata, sampling_metadata,
+         multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)
 
         model_executable = self.model
@@ -399,6 +411,8 @@ def execute_model(
             "kv_caches": kv_caches,
             "attn_metadata": attn_metadata,
         }
+        if self.vision_language_config:
+            execute_model_kwargs.update({"image_input": multi_modal_input})
 
         hidden_states = model_executable(**execute_model_kwargs)
 
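For reference, below is a minimal, self-contained sketch of the batching step this diff adds to _prepare_prompt: per-sequence image tensors from multi_modal_data are concatenated along the batch dimension and later forwarded to the model as the "image_input" keyword argument in execute_model. The FakeMultiModalData and FakeSeqGroupMetadata classes and the gather_multi_modal_input helper are hypothetical stand-ins for illustration only, not vLLM APIs.

# Illustrative sketch (not vLLM code): mirrors the multi-modal batching
# introduced in _prepare_prompt above. All names here are hypothetical.
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class FakeMultiModalData:
    # Stand-in for the .multi_modal_data payload; `data` would hold e.g.
    # pixel values of shape (1, C, H, W).
    data: torch.Tensor


@dataclass
class FakeSeqGroupMetadata:
    # Stand-in for SequenceGroupMetadata; only the field used here.
    multi_modal_data: Optional[FakeMultiModalData]


def gather_multi_modal_input(
        seq_group_metadata_list: List[FakeSeqGroupMetadata],
        device: torch.device) -> Optional[torch.Tensor]:
    """Concatenate per-sequence image tensors along dim 0, as the diff's
    _prepare_prompt does, returning None when no sequence carries images."""
    multi_modal_input_list: List[torch.Tensor] = []
    for seq_group_metadata in seq_group_metadata_list:
        if seq_group_metadata.multi_modal_data:
            multi_modal_input_list.append(
                seq_group_metadata.multi_modal_data.data)
    if not multi_modal_input_list:
        return None
    return torch.cat(multi_modal_input_list, dim=0).to(device)


if __name__ == "__main__":
    # Two prompt sequences; only the first carries an image.
    batch = [
        FakeSeqGroupMetadata(FakeMultiModalData(torch.zeros(1, 3, 224, 224))),
        FakeSeqGroupMetadata(None),
    ]
    image_input = gather_multi_modal_input(batch, torch.device("cpu"))
    print(None if image_input is None else image_input.shape)
    # -> torch.Size([1, 3, 224, 224]); this is the tensor that execute_model
    # would pass through execute_model_kwargs as "image_input".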