Skip to content

Commit 6e3d1d8

Browse files
[Model] Use merge_by_field_config for MM models (R-T)
Signed-off-by: Ayush Satyam <[email protected]>
1 parent 824a3f4 commit 6e3d1d8

File tree

3 files changed: 12 additions (+12) and 9 deletions (−9) lines changed

vllm/model_executor/models/step3_vl.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,6 @@
4949
from .utils import (
5050
AutoWeightsLoader,
5151
WeightsMapper,
52-
flatten_bn,
5352
init_vllm_registered_model,
5453
maybe_prefix,
5554
)
@@ -895,6 +894,8 @@ def forward(
895894
dummy_inputs=Step3VLDummyInputsBuilder,
896895
)
897896
class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
897+
merge_by_field_config = True
898+
898899
hf_to_vllm_mapper = WeightsMapper(
899900
orig_to_new_prefix={
900901
"model.": "language_model.model.",
@@ -982,18 +983,21 @@ def _parse_and_validate_image_input(
982983
return None
983984

984985
if pixel_values is not None:
985-
pixel_values = flatten_bn(pixel_values, concat=True)
986986
if pixel_values.dim() >= 3:
987987
pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
988988
if patch_pixel_values is not None:
989-
patch_pixel_values = flatten_bn(patch_pixel_values, concat=True)
990989
patch_pixel_values = patch_pixel_values.view(
991990
-1, *patch_pixel_values.shape[-3:]
992991
)
993992
# Handle empty patch_pixel_values by setting to None
994993
if patch_pixel_values.shape[0] == 0:
995994
patch_pixel_values = None
996-
num_patches = flatten_bn(num_patches, concat=True).tolist()
995+
if isinstance(num_patches, torch.Tensor):
996+
num_patches = num_patches.tolist()
997+
elif isinstance(num_patches, list):
998+
num_patches = [
999+
n.item() if isinstance(n, torch.Tensor) else n for n in num_patches
1000+
]
9971001

9981002
return Step3VLImagePixelInputs(
9991003
type="pixel_values",

vllm/model_executor/models/tarsier.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -47,11 +47,7 @@
4747
from .clip import CLIPVisionModel
4848
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
4949
from .siglip import SiglipVisionModel
50-
from .utils import (
51-
AutoWeightsLoader,
52-
init_vllm_registered_model,
53-
maybe_prefix,
54-
)
50+
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
5551
from .vision import (
5652
VisionEncoderInfo,
5753
get_num_selected_vision_tokens,
@@ -405,6 +401,8 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers_total: int) ->
405401
class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
406402
merge_by_field_config = True
407403

404+
merge_by_field_config = True
405+
408406
packed_modules_mapping = {
409407
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
410408
"gate_up_proj": ["gate_proj", "up_proj"],

vllm/model_executor/models/terratorch.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -226,6 +226,7 @@ def apply(
226226
dummy_inputs=TerratorchInputBuilder,
227227
)
228228
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
229+
merge_by_field_config = True
229230
supports_multimodal_raw_input_only = True
230231
is_pooling_model = True
231232

0 commit comments

Comments (0)