
Commit 296cdf8

[Misc] Add vision language model support to CPU backend (#3968)
1 parent 747b1a7 commit 296cdf8
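
For context, a rough sketch of the kind of offline multimodal inference this commit enables on the CPU backend. The model name, vision-language parameters, and MultiModalData usage below are assumptions drawn from the LLaVA example of roughly this vLLM version, not from this diff, and may differ in other releases; the CPU backend itself is assumed to be selected at build time (e.g. VLLM_TARGET_DEVICE=cpu).

import torch
from vllm import LLM
from vllm.sequence import MultiModalData  # import path assumed for this version

# Hypothetical invocation; the vision-language parameters follow the LLaVA
# example of this era and are assumptions, not part of this commit.
llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    image_input_type="pixel_values",
    image_token_id=32000,
    image_input_shape="1,3,336,336",
    image_feature_size=576,
)

# One "<image>" placeholder token per image feature, followed by the text turn.
prompt = "<image>" * 576 + "\nUSER: What is the content of this image?\nASSISTANT:"
pixel_values = torch.zeros(1, 3, 336, 336)  # stand-in for real preprocessed pixels

outputs = llm.generate(
    prompt,
    multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
                                    data=pixel_values))
print(outputs[0].outputs[0].text)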

File tree

3 files changed: +53 -32 lines changed

vllm/executor/cpu_executor.py
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_worker.py

vllm/executor/cpu_executor.py

Lines changed: 1 addition & 0 deletions

@@ -45,6 +45,7 @@ def _init_worker(self):
             rank=0,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
+            vision_language_config=self.vision_language_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=True,
         )
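
The executor change above only threads the already-parsed vision-language config through to the worker it creates. A minimal standalone sketch of that plumbing pattern (illustrative Python with hypothetical stand-in classes, not vLLM's actual ones):

from dataclasses import dataclass
from typing import Optional

@dataclass
class VisionLanguageConfig:
    image_feature_size: int

class ModelRunner:
    def __init__(self, vision_language_config: Optional[VisionLanguageConfig]):
        self.vision_language_config = vision_language_config

class Worker:
    def __init__(self, vision_language_config: Optional[VisionLanguageConfig] = None):
        # The worker does not interpret the config; it just hands it to the runner.
        self.model_runner = ModelRunner(vision_language_config=vision_language_config)

class Executor:
    def __init__(self, vision_language_config: Optional[VisionLanguageConfig]):
        self.vision_language_config = vision_language_config

    def _init_worker(self) -> Worker:
        # Mirrors the diff: the executor forwards its config unchanged.
        return Worker(vision_language_config=self.vision_language_config)

worker = Executor(VisionLanguageConfig(image_feature_size=576))._init_worker()
print(worker.model_runner.vision_language_config)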

vllm/worker/cpu_model_runner.py

Lines changed: 37 additions & 23 deletions

@@ -5,7 +5,7 @@

 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
+                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
@@ -29,6 +29,7 @@ def __init__(
         device_config: DeviceConfig,
         load_config: LoadConfig,
         lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig],
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
         *args,
@@ -38,6 +39,7 @@ def __init__(
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.lora_config = lora_config
+        self.vision_language_config = vision_language_config
         self.load_config = load_config
         self.is_driver_worker = is_driver_worker

@@ -59,13 +61,14 @@ def __init__(
         self.block_size: int  # Set after initial profiling.

     def load_model(self) -> None:
-        self.model = get_model(model_config=self.model_config,
-                               load_config=self.load_config,
-                               device_config=self.device_config,
-                               vision_language_config=None,
-                               lora_config=self.lora_config,
-                               parallel_config=self.parallel_config,
-                               scheduler_config=self.scheduler_config)
+        self.model = get_model(
+            model_config=self.model_config,
+            load_config=self.load_config,
+            device_config=self.device_config,
+            vision_language_config=self.vision_language_config,
+            lora_config=self.lora_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config)

     def _prepare_prompt(
         self,
@@ -76,6 +79,7 @@ def _prepare_prompt(
         input_positions: List[int] = []
         slot_mapping: List[int] = []
         prompt_lens: List[int] = []
+        multi_modal_input_list: List[torch.Tensor] = []

         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -96,6 +100,10 @@ def _prepare_prompt(
             # is always the first token in the sequence.
             input_positions.extend(list(range(computed_len, prompt_len)))

+            if seq_group_metadata.multi_modal_data:
+                multi_modal_input_list.append(
+                    seq_group_metadata.multi_modal_data.data)
+
             # Compute the slot mapping.
             block_table = seq_group_metadata.block_tables[seq_id]
             # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
@@ -118,6 +126,15 @@ def _prepare_prompt(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

+        if multi_modal_input_list:
+            assert self.vision_language_config, (
+                "Multi-modal inputs are only supported by "
+                "vision language models.")
+            multi_modal_input = torch.cat(multi_modal_input_list,
+                                          dim=0).to(self.device)
+        else:
+            multi_modal_input = None
+
         num_prompt_tokens = len(input_tokens)

         input_tokens = torch.tensor(input_tokens,
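
The batching step added above concatenates the per-sequence multi-modal tensors along the batch dimension. A minimal self-contained sketch of that step (illustrative, not vLLM code); the (1, 3, 336, 336) pixel-value shape is an assumption for the example:

import torch

# Each prompt that carries multi-modal data contributes one tensor, e.g.
# preprocessed pixel values of shape (1, 3, 336, 336). Concatenating along
# dim 0 yields a single batch tensor; with no images, the input stays None.
multi_modal_input_list = [torch.randn(1, 3, 336, 336) for _ in range(2)]

if multi_modal_input_list:
    multi_modal_input = torch.cat(multi_modal_input_list, dim=0)  # (2, 3, 336, 336)
else:
    multi_modal_input = None

print(None if multi_modal_input is None else multi_modal_input.shape)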
@@ -144,12 +161,8 @@ def _prepare_prompt(
             slot_mapping=slot_mapping,
             kv_cache_dtype=self.kv_cache_dtype,
         )
-        return (
-            input_tokens,
-            input_positions,
-            attn_metadata,
-            prompt_lens,
-        )
+        return (input_tokens, input_positions, attn_metadata, prompt_lens,
+                multi_modal_input)

     def _prepare_decode(
         self,
@@ -336,14 +349,16 @@ def prepare_input_tensors(
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
               SamplingMetadata]:
+        multi_modal_input = None
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
             # all decodes.
             is_prompt = seq_group_metadata_list[0].is_prompt
             # Prepare input tensors.
             if is_prompt:
-                (input_tokens, input_positions, attn_metadata,
-                 prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
+                (input_tokens, input_positions, attn_metadata, prompt_lens,
+                 multi_modal_input
+                 ) = self._prepare_prompt(seq_group_metadata_list)
             else:
                 (input_tokens, input_positions,
                  attn_metadata) = self._prepare_decode(seq_group_metadata_list)
@@ -376,20 +391,17 @@ def prepare_input_tensors(
                 perform_sampling=False,
             )

-        return (
-            input_tokens,
-            input_positions,
-            attn_metadata,
-            sampling_metadata,
-        )
+        return (input_tokens, input_positions, attn_metadata,
+                sampling_metadata, multi_modal_input)

     @torch.inference_mode()
     def execute_model(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
-        (input_tokens, input_positions, attn_metadata, sampling_metadata
+        (input_tokens, input_positions, attn_metadata, sampling_metadata,
+         multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

         model_executable = self.model
@@ -399,6 +411,8 @@ def execute_model(
             "kv_caches": kv_caches,
             "attn_metadata": attn_metadata,
         }
+        if self.vision_language_config:
+            execute_model_kwargs.update({"image_input": multi_modal_input})

         hidden_states = model_executable(**execute_model_kwargs)
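
The image_input plumbing above only adds the extra keyword when a vision-language config is present, so text-only models are called exactly as before. A self-contained sketch of that call pattern (illustrative, not vLLM code):

def model_forward(input_ids, positions, image_input=None):
    # Stand-in for the model's forward pass.
    return f"tokens={len(input_ids)}, has_image={image_input is not None}"

vision_language_config = {"image_feature_size": 576}  # truthy stand-in; None for text-only
multi_modal_input = [[0.0, 0.1]]                      # stand-in for batched pixel values

execute_model_kwargs = {"input_ids": [1, 2, 3], "positions": [0, 1, 2]}
if vision_language_config:
    execute_model_kwargs.update({"image_input": multi_modal_input})

print(model_forward(**execute_model_kwargs))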

vllm/worker/cpu_worker.py

Lines changed: 15 additions & 9 deletions

@@ -6,7 +6,8 @@

 from vllm.attention import get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ParallelConfig, SchedulerConfig)
+                         ModelConfig, ParallelConfig, SchedulerConfig,
+                         VisionLanguageConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
                               init_distributed_environment)
@@ -122,6 +123,7 @@ def __init__(
         rank: int,
         distributed_init_method: str,
         lora_config: Optional[LoRAConfig] = None,
+        vision_language_config: Optional[VisionLanguageConfig] = None,
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
     ) -> None:
@@ -135,21 +137,25 @@ def __init__(
         self.rank = rank
         self.distributed_init_method = distributed_init_method
         self.lora_config = lora_config
+        self.vision_language_config = vision_language_config
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
+
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()
-        self.model_runner = CPUModelRunner(model_config,
-                                           parallel_config,
-                                           scheduler_config,
-                                           device_config,
-                                           load_config=self.load_config,
-                                           lora_config=self.lora_config,
-                                           kv_cache_dtype=kv_cache_dtype,
-                                           is_driver_worker=is_driver_worker)
+        self.model_runner = CPUModelRunner(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            device_config,
+            load_config=self.load_config,
+            lora_config=self.lora_config,
+            vision_language_config=self.vision_language_config,
+            kv_cache_dtype=kv_cache_dtype,
+            is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: CPUCacheEngine
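
The new vision_language_config keyword defaults to None, so existing call sites that never pass it keep working. A minimal sketch of that compatibility pattern (illustrative Python, not vLLM's real classes; the field names are assumptions):

from typing import Optional

class CPUWorkerSketch:
    """Stand-in for the real worker; only the new keyword is shown."""

    def __init__(self,
                 model_name: str,
                 vision_language_config: Optional[dict] = None) -> None:
        self.model_name = model_name
        # None means "text-only model"; downstream code checks truthiness.
        self.vision_language_config = vision_language_config

text_only = CPUWorkerSketch("facebook/opt-125m")  # old call sites still valid
multimodal = CPUWorkerSketch("llava-1.5-7b",
                             vision_language_config={"image_feature_size": 576})
print(text_only.vision_language_config, multimodal.vision_language_config)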
