Skip to content

Commit 8323233

Browse files
DarkLight1337 authored and zhewenl committed
[Model] Merge SupportsMultiModalWithRawInput with SupportsMultiModal (vllm-project#23749)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent afa3e06 commit 8323233

File tree

5 files changed

+30
-50
lines changed

5 files changed

+30
-50
lines changed

vllm/config/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1698,6 +1698,10 @@ def uses_mrope(self) -> bool:
16981698
def is_multimodal_model(self) -> bool:
16991699
return self.multimodal_config is not None
17001700

1701+
@property
1702+
def is_multimodal_raw_input_only_model(self) -> bool:
1703+
return self._model_info.supports_multimodal_raw_input_only
1704+
17011705
@property
17021706
def is_cross_encoder(self) -> bool:
17031707
return (self._model_info.supports_cross_encoding
@@ -1707,10 +1711,6 @@ def is_cross_encoder(self) -> bool:
17071711
def is_pp_supported(self) -> bool:
17081712
return self._model_info.supports_pp
17091713

1710-
@property
1711-
def is_multimodal_raw_input_supported(self) -> bool:
1712-
return self._model_info.supports_multimodal_raw_input
1713-
17141714
@property
17151715
def is_attention_free(self) -> bool:
17161716
return self._model_info.is_attention_free

vllm/model_executor/models/interfaces.py

Lines changed: 11 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,12 @@ class SupportsMultiModal(Protocol):
5252
MRO of your model class.
5353
"""
5454

55+
supports_multimodal_raw_input_only: ClassVar[bool] = False
56+
"""
57+
A flag that indicates this model supports multi-modal inputs and processes
58+
them in their raw form and not embeddings.
59+
"""
60+
5561
supports_encoder_tp_data: ClassVar[bool] = False
5662
"""
5763
A flag that indicates whether this model supports
@@ -143,43 +149,14 @@ def supports_multimodal(
143149
return getattr(model, "supports_multimodal", False)
144150

145151

146-
def supports_multimodal_encoder_tp_data(
152+
def supports_multimodal_raw_input_only(
147153
model: Union[type[object], object]) -> bool:
148-
return getattr(model, "supports_encoder_tp_data", False)
149-
150-
151-
@runtime_checkable
152-
class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
153-
"""The interface required for all multi-modal models."""
154-
155-
supports_multimodal_raw_input: ClassVar[Literal[True]] = True
156-
"""
157-
A flag that indicates this model supports multi-modal inputs and processes
158-
them in their raw form and not embeddings.
159-
160-
Note:
161-
There is no need to redefine this flag if this class is in the
162-
MRO of your model class.
163-
"""
164-
165-
166-
@overload
167-
def supports_multimodal_raw_input(
168-
model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
169-
...
154+
return getattr(model, "supports_multimodal_raw_input_only", False)
170155

171156

172-
@overload
173-
def supports_multimodal_raw_input(
174-
model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
175-
...
176-
177-
178-
def supports_multimodal_raw_input(
179-
model: Union[type[object], object]
180-
) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
181-
TypeIs[SupportsMultiModalWithRawInput]]:
182-
return getattr(model, "supports_multimodal_raw_input", False)
157+
def supports_multimodal_encoder_tp_data(
158+
model: Union[type[object], object]) -> bool:
159+
return getattr(model, "supports_encoder_tp_data", False)
183160

184161

185162
@runtime_checkable

vllm/model_executor/models/prithvi_geospatial_mae.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@
4141
from vllm.sequence import IntermediateTensors
4242

4343
from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
44-
SupportsMultiModalWithRawInput)
44+
SupportsMultiModal)
4545
from .interfaces_base import default_pooling_type
4646

4747

@@ -174,10 +174,10 @@ def apply(
174174
info=PrithviGeoSpatialMAEProcessingInfo,
175175
dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
176176
)
177-
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
178-
SupportsMultiModalWithRawInput):
177+
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
179178
"""Prithvi Masked Autoencoder"""
180179

180+
supports_multimodal_raw_input_only = True
181181
is_pooling_model = True
182182

183183
@classmethod

vllm/model_executor/models/registry.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929
is_hybrid, supports_cross_encoding,
3030
supports_multimodal,
3131
supports_multimodal_encoder_tp_data,
32-
supports_multimodal_raw_input, supports_pp,
32+
supports_multimodal_raw_input_only, supports_pp,
3333
supports_transcription, supports_v0_only)
3434
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
3535
is_text_generation_model)
@@ -326,7 +326,7 @@ class _ModelInfo:
326326
default_pooling_type: str
327327
supports_cross_encoding: bool
328328
supports_multimodal: bool
329-
supports_multimodal_raw_input: bool
329+
supports_multimodal_raw_input_only: bool
330330
supports_multimodal_encoder_tp_data: bool
331331
supports_pp: bool
332332
has_inner_state: bool
@@ -346,7 +346,8 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
346346
default_pooling_type=get_default_pooling_type(model),
347347
supports_cross_encoding=supports_cross_encoding(model),
348348
supports_multimodal=supports_multimodal(model),
349-
supports_multimodal_raw_input=supports_multimodal_raw_input(model),
349+
supports_multimodal_raw_input_only=
350+
supports_multimodal_raw_input_only(model),
350351
supports_multimodal_encoder_tp_data=
351352
supports_multimodal_encoder_tp_data(model),
352353
supports_pp=supports_pp(model),
@@ -743,13 +744,13 @@ def is_multimodal_model(
743744
model_cls, _ = self.inspect_model_cls(architectures, model_config)
744745
return model_cls.supports_multimodal
745746

746-
def supports_multimodal_raw_input(
747+
def is_multimodal_raw_input_only_model(
747748
self,
748749
architectures: Union[str, list[str]],
749750
model_config: ModelConfig,
750751
) -> bool:
751752
model_cls, _ = self.inspect_model_cls(architectures, model_config)
752-
return model_cls.supports_multimodal_raw_input
753+
return model_cls.supports_multimodal_raw_input_only
753754

754755
def is_pp_supported_model(
755756
self,

vllm/v1/worker/gpu_model_runner.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -139,8 +139,9 @@ def __init__(
139139
cache_config.cache_dtype]
140140

141141
self.is_pooling_model = model_config.pooler_config is not None
142-
self.is_multimodal_raw_input_supported = (
143-
model_config.is_multimodal_raw_input_supported)
142+
self.is_multimodal_raw_input_only_model = (
143+
model_config.is_multimodal_raw_input_only_model)
144+
144145
self.max_model_len = model_config.max_model_len
145146
self.max_num_tokens = scheduler_config.max_num_batched_tokens
146147
self.max_num_reqs = scheduler_config.max_num_seqs
@@ -612,7 +613,7 @@ def _extract_mm_kwargs(
612613
self,
613614
scheduler_output: "SchedulerOutput",
614615
) -> BatchedTensorInputs:
615-
if not self.is_multimodal_raw_input_supported or not scheduler_output: # noqa: SIM102
616+
if not scheduler_output or not self.is_multimodal_raw_input_only_model:
616617
return {}
617618

618619
mm_kwargs = list[MultiModalKwargsItem]()
@@ -631,8 +632,9 @@ def _extract_mm_kwargs(
631632
return mm_kwargs_combined
632633

633634
def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
634-
if not self.is_multimodal_raw_input_supported:
635+
if not self.is_multimodal_raw_input_only_model:
635636
return {}
637+
636638
mm_budget = self.mm_budget
637639
assert mm_budget is not None
638640

0 commit comments

Comments
 (0)