Commit ae96ef8

[VLM] Calculate maximum number of multi-modal tokens by model (#6121)
1 parent 69ec3ca commit ae96ef8

12 files changed: +260 -90 lines changed

docs/source/dev/multimodal/adding_multimodal_model.rst

Lines changed: 48 additions & 20 deletions
@@ -51,40 +51,68 @@ As usual, follow :ref:`these steps <adding_a_new_model>` to implement the model
  2. Register input mappers
  -------------------------

- For each modality type to support, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
+ For each modality type that the model accepts as input, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
  This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in :meth:`~torch.nn.Module.forward`.

  .. code-block:: diff

-     from vllm.model_executor.models.interfaces import SupportsVision
+     from vllm.model_executor.models.interfaces import SupportsVision
      + from vllm.multimodal import MULTIMODAL_REGISTRY

-     + @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
-     + @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
-     class YourModelForImage2Seq(nn.Module, SupportsVision):
+     + @MULTIMODAL_REGISTRY.register_image_input_mapper()
+     class YourModelForImage2Seq(nn.Module, SupportsVision):

  A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.

  .. seealso::
     :ref:`input_processing_pipeline`


- 3. (Optional) Register dummy data
+ 3. Register maximum number of multimodal tokens
+ ----------------------------------------------------------
+
+ For each modality type that the model accepts as input, calculate the maximum possible number of tokens
+ and register it via :meth:`MULTIMODAL_REGISTRY.register_max_multimodal_tokens <vllm.multimodal.MultiModalRegistry.register_max_multimodal_tokens>`.
+
+ .. code-block:: diff
+
+     from vllm.inputs import INPUT_REGISTRY
+     from vllm.model_executor.models.interfaces import SupportsVision
+     from vllm.multimodal import MULTIMODAL_REGISTRY
+
+     @MULTIMODAL_REGISTRY.register_image_input_mapper()
+     + @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
+     @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
+     class YourModelForImage2Seq(nn.Module, SupportsVision):
+
+ Here are some examples:
+
+ - Image inputs (static feature size): `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
+ - Image inputs (dynamic feature size): `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__
+
+ .. seealso::
+    :ref:`input_processing_pipeline`
+
+
+ 4. (Optional) Register dummy data
  ---------------------------------

  During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
  In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.

  .. code-block:: diff

-     from vllm.inputs import INPUT_REGISTRY
-     from vllm.model_executor.models.interfaces import SupportsVision
-     from vllm.multimodal import MULTIMODAL_REGISTRY
+     from vllm.inputs import INPUT_REGISTRY
+     from vllm.model_executor.models.interfaces import SupportsVision
+     from vllm.multimodal import MULTIMODAL_REGISTRY

-     @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
-     @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
+     @MULTIMODAL_REGISTRY.register_image_input_mapper()
+     @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
      + @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
-     class YourModelForImage2Seq(nn.Module, SupportsVision):
+     class YourModelForImage2Seq(nn.Module, SupportsVision):
+
+ .. note::
+    The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.

  Here are some examples:

@@ -95,7 +123,7 @@ Here are some examples:
     :ref:`input_processing_pipeline`


- 4. (Optional) Register input processor
+ 5. (Optional) Register input processor
  --------------------------------------

  Sometimes, there is a need to process inputs at the :class:`~vllm.LLMEngine` level before they are passed to the model executor.

@@ -104,15 +132,15 @@ You can register input processors via :meth:`INPUT_REGISTRY.register_input_proce

  .. code-block:: diff

-     from vllm.inputs import INPUT_REGISTRY
-     from vllm.model_executor.models.interfaces import SupportsVision
-     from vllm.multimodal import MULTIMODAL_REGISTRY
+     from vllm.inputs import INPUT_REGISTRY
+     from vllm.model_executor.models.interfaces import SupportsVision
+     from vllm.multimodal import MULTIMODAL_REGISTRY

-     @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
-     @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
-     @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
+     @MULTIMODAL_REGISTRY.register_image_input_mapper()
+     @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
+     @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
      + @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
-     class YourModelForImage2Seq(nn.Module, SupportsVision):
+     class YourModelForImage2Seq(nn.Module, SupportsVision):

  A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
  Here are some examples:
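For orientation, the steps documented above compose as follows. This is a minimal sketch and not part of the commit: the model class, the fixed feature size, and the two placeholder callbacks (dummy_data_for_your_model, input_processor_for_your_model) are hypothetical, while the registry decorators and import paths are the ones the guide names.

# Minimal sketch (not from the commit): a hypothetical model whose vision encoder
# always emits the same number of image tokens, wired up with all four registrations.
import torch.nn as nn

from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY

# Illustrative constant; a real model would derive this from its vision config.
FIXED_IMAGE_FEATURE_SIZE = 576


def get_max_your_model_image_tokens(ctx: InputContext) -> int:
    # Static feature size: every image maps to the same number of tokens.
    return FIXED_IMAGE_FEATURE_SIZE


def dummy_data_for_your_model(ctx: InputContext, seq_len: int):
    # Placeholder: must return dummy sequence data plus dummy multi-modal data
    # that together cover the maximum token count registered above.
    raise NotImplementedError


def input_processor_for_your_model(ctx: InputContext, llm_inputs: LLMInputs) -> LLMInputs:
    # Placeholder: e.g. expand the image placeholder token to
    # FIXED_IMAGE_FEATURE_SIZE repetitions before the model runs.
    return llm_inputs


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_your_model_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_your_model)
@INPUT_REGISTRY.register_input_processor(input_processor_for_your_model)
class YourModelForImage2Seq(nn.Module, SupportsVision):
    ...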

docs/source/models/vlm.rst

Lines changed: 4 additions & 14 deletions
@@ -25,13 +25,8 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``

  .. important::
      We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
-     the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, and internally we will construct data structures for
-     every model to perform profiling with.
-
-     This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
-     :meth:`MULTIMODAL_REGISTRY.get_num_input_tokens <vllm.multimodal.MultiModalRegistry.get_num_input_tokens>`
-     for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
-     with a more accurate profiling strategy in the future.
+     the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
+     internally for each model.


  To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:

@@ -104,13 +99,8 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with

  .. important::
      We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
-     the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, and internally we will construct data structures for
-     every model to perform profiling with.
-
-     This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
-     :meth:`MULTIMODAL_REGISTRY.get_num_input_tokens <vllm.multimodal.MultiModalRegistry.get_num_input_tokens>`
-     for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
-     with a more accurate profiling strategy in the future.
+     the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
+     internally for each model.

  To consume the server, you can use the OpenAI client like in the example below:
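In practice, the simplification described in both hunks looks like the following. A minimal sketch, assuming offline inference through the ``LLM`` class: no ``image_feature_size`` or any other vision-specific CLI argument is passed, since vLLM now derives the maximum image token count from the model's own config.

# Sketch: no vision-specific engine args are needed anymore.
from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # image_feature_size is gone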

vllm/inputs/registry.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def get_hf_config(self, hf_config_type: Type[C]) -> C:
          additionally checking its type.

          Raises:
-             ValueError: If the model is not of the specified type.
+             TypeError: If the model is not of the specified type.
          """

          hf_config = self.model_config.hf_config
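The docstring now matches the behaviour callers rely on: the model-specific helpers ask the InputContext for a concrete config class and let the type check fail loudly. A hedged illustration with a hypothetical caller (the try/except is only there to show the exception type, not how the model files use it):

# Illustration of the corrected contract: get_hf_config raises TypeError,
# not ValueError, when the loaded config is not of the requested class.
from transformers import LlavaConfig

from vllm.inputs import InputContext


def expects_llava_config(ctx: InputContext) -> None:
    try:
        ctx.get_hf_config(LlavaConfig)
    except TypeError as exc:
        print(f"Not a LLaVA-style checkpoint: {exc}")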

vllm/model_executor/models/clip.py

Lines changed: 4 additions & 0 deletions
@@ -35,6 +35,10 @@ def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
                                  patch_size=hf_config.patch_size)


+ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
+     return get_clip_image_feature_size(hf_config)
+
+
  def dummy_seq_data_for_clip(
      hf_config: CLIPVisionConfig,
      seq_len: int,
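As a sanity check on the new helper, a hedged worked example. It assumes get_clip_image_feature_size reduces to the number of ViT patches, i.e. (image_size // patch_size) ** 2, which is what the surrounding file computes; for LLaVA-1.5's CLIP ViT-L/14 tower at 336 px this comes out to the familiar 576 image tokens.

# Hedged worked example: assumes get_clip_image_feature_size returns the
# ViT patch count, i.e. (image_size // patch_size) ** 2.
from transformers import CLIPVisionConfig

from vllm.model_executor.models.clip import get_max_clip_image_tokens

cfg = CLIPVisionConfig(image_size=336, patch_size=14)
print(get_max_clip_image_tokens(cfg))  # expected: (336 // 14) ** 2 == 576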

vllm/model_executor/models/llava.py

Lines changed: 13 additions & 1 deletion
@@ -21,7 +21,7 @@
  from vllm.sequence import IntermediateTensors, SamplerOutput

  from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
-                    input_processor_for_clip)
+                    get_max_clip_image_tokens, input_processor_for_clip)
  from .interfaces import SupportsVision
  from .utils import merge_vision_embeddings

@@ -62,6 +62,17 @@ class LlavaImagePixelInputs(TypedDict):
  LlavaImageInputs = LlavaImagePixelInputs


+ def get_max_llava_image_tokens(ctx: InputContext):
+     hf_config = ctx.get_hf_config(LlavaConfig)
+     vision_config = hf_config.vision_config
+
+     if isinstance(vision_config, CLIPVisionConfig):
+         return get_max_clip_image_tokens(vision_config)
+
+     msg = f"Unsupported vision config: {type(vision_config)}"
+     raise NotImplementedError(msg)
+
+
  def dummy_data_for_llava(ctx: InputContext, seq_len: int):
      hf_config = ctx.get_hf_config(LlavaConfig)
      vision_config = hf_config.vision_config

@@ -102,6 +113,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):


  @MULTIMODAL_REGISTRY.register_image_input_mapper()
+ @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
  @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
  @INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
  class LlavaForConditionalGeneration(nn.Module, SupportsVision):

vllm/model_executor/models/llava_next.py

Lines changed: 12 additions & 0 deletions
@@ -127,6 +127,17 @@ def get_llava_next_image_feature_size(
      raise NotImplementedError(msg)


+ def get_max_llava_next_image_tokens(ctx: InputContext):
+     # Result in the max possible feature size (2x2 grid of 336x336px tiles)
+     dummy_height = dummy_width = 448
+
+     return get_llava_next_image_feature_size(
+         ctx.get_hf_config(LlavaNextConfig),
+         input_height=dummy_height,
+         input_width=dummy_width,
+     )
+
+
  def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
      hf_config = ctx.get_hf_config(LlavaNextConfig)
      vision_config = hf_config.vision_config

@@ -198,6 +209,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):


  @MULTIMODAL_REGISTRY.register_image_input_mapper()
+ @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
  @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
  @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
  class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
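To see why a 448x448 dummy image is used here, a hedged back-of-the-envelope count. It assumes the usual LLaVA-NeXT anyres layout with a CLIP ViT-L/14-336 tower: a 336 px global view plus a 2x2 grid of 336 px tiles, with one image-newline token per row of the high-resolution grid; the numbers below are not taken from the commit itself.

# Hedged arithmetic sketch for the 448x448 dummy image (assumptions above).
patches_per_side = 336 // 14             # 24 patches along each side of a tile
base_tokens = patches_per_side ** 2      # 576 tokens for the global 336px view
grid_side = 2 * patches_per_side         # 48 patches per side for the 2x2 grid
tile_tokens = grid_side ** 2             # 2304 tokens for the high-res tiles
newline_tokens = grid_side               # one newline token per grid row
print(base_tokens + tile_tokens + newline_tokens)  # 2928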

vllm/model_executor/models/phi3v.py

Lines changed: 12 additions & 0 deletions
@@ -321,6 +321,17 @@ def get_phi3v_image_feature_size(
              + (new_height // 336 + 1) * 12


+ def get_max_phi3v_image_tokens(ctx: InputContext):
+     # Result in the max possible feature size (h:w = 16:1)
+     dummy_height, dummy_width = 8000, 50
+
+     return get_phi3v_image_feature_size(
+         ctx.get_hf_config(PretrainedConfig),
+         input_height=dummy_height,
+         input_width=dummy_width,
+     )
+
+
  def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
      # Result in the max possible feature size (h:w = 16:1)
      dummy_height, dummy_width = 8000, 50

@@ -429,6 +440,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):


  @MULTIMODAL_REGISTRY.register_image_input_mapper()
+ @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
  @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
  @INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
  class Phi3VForCausalLM(nn.Module, SupportsVision):
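For completeness, a hedged estimate of what the 8000x50 dummy image yields. It assumes the Phi-3-Vision HD transform maps a 16:1 input (with the default num_crops=16) to a 1x16 column of 336 px tiles, and that the feature size follows the formula whose tail is visible in the context above, i.e. (h_tiles * w_tiles + 1) * 144 + 1 + (h_tiles + 1) * 12; neither assumption is spelled out in this excerpt.

# Hedged arithmetic sketch for the 8000x50 (h:w = 16:1) dummy image.
h_tiles, w_tiles = 16, 1                        # assumed HD-transform tiling
global_and_tile_tokens = (h_tiles * w_tiles + 1) * 144 + 1
separator_tokens = (h_tiles + 1) * 12           # term shown in the diff context
print(global_and_tile_tokens + separator_tokens)  # 2448 + 1 + 204 == 2653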
