Commit 59a85c3

[Model] Use merge_by_field_config for MM models (H-L) (#26230)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 119f006 commit 59a85c3
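
In effect, this commit opts the affected H-L multimodal models into processor-side merging of multimodal keyword arguments. The sketch below is illustrative only (toy shapes, not vLLM's actual code): with `merge_by_field_config = True`, per-request tensors arrive already concatenated along the item axis, so the models' parse methods no longer need `flatten_bn()` or the per-model `_validate_and_reshape_mm_tensor()` helpers removed here.

```python
import torch

# Two requests contribute 2 and 3 image patches respectively (toy shapes).
per_request = [torch.randn(2, 3, 14, 14), torch.randn(3, 3, 14, 14)]

# The processor merges them up front, yielding the flattened
# ("bnp", 3, "ps", "ps") layout that the updated TensorSchemas expect.
merged = torch.cat(per_request, dim=0)
assert merged.shape == (5, 3, 14, 14)
```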

6 files changed (+29, -161 lines)

examples/offline_inference/vision_language_multi_image.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -548,7 +548,7 @@ def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
     engine_args = EngineArgs(
         model=model_name,
         trust_remote_code=True,
-        max_model_len=8192,
+        max_model_len=32768,
         max_num_seqs=5,
         limit_mm_per_prompt={"image": len(image_urls)},
     )
```
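
The larger `max_model_len` presumably leaves headroom for the vision tokens of several images per prompt. A rough, hypothetical budget check (the per-image token count is made up for illustration, not taken from this commit):

```python
MAX_MODEL_LEN = 32768  # value from the diff above

def fits(text_tokens: int, vision_tokens_per_image: int, num_images: int) -> bool:
    # Toy budget: all text and vision tokens must fit in the context window.
    return text_tokens + num_images * vision_tokens_per_image <= MAX_MODEL_LEN

# A 4-image prompt at ~2048 vision tokens each would have overflowed the
# old 8192-token limit but fits comfortably in 32768:
assert fits(text_tokens=512, vision_tokens_per_image=2048, num_images=4)
assert not (512 + 4 * 2048 <= 8192)
```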

vllm/model_executor/models/idefics3.py

Lines changed: 9 additions & 23 deletions
```diff
@@ -53,7 +53,7 @@
 # yapf: enable
 from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
 from .llama import LlamaModel
-from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix


 class Idefics3ImagePixelInputs(TensorSchema):
@@ -67,7 +67,7 @@ class Idefics3ImagePixelInputs(TensorSchema):
     """
     type: Literal["pixel_values"]
    pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
-    pixel_attention_mask: torch.Tensor
+    pixel_attention_mask: Annotated[torch.Tensor, TensorShape("bnp", "h", "w")]
     num_patches: Annotated[torch.Tensor, TensorShape("bn")]


@@ -569,6 +569,8 @@ def forward(
                                         dummy_inputs=Idefics3DummyInputsBuilder)
 class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -621,37 +623,21 @@ def _parse_and_validate_image_input(
             return None

         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
-
             return Idefics3ImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds, concat=True),
+                data=image_embeds,
             )

         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             pixel_attention_mask = kwargs.pop("pixel_attention_mask")
-            if not isinstance(pixel_attention_mask, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel_attention_mask. "
-                                 f"Got type: {type(pixel_attention_mask)}")
-
             num_patches = kwargs.pop("num_patches")
-            if not isinstance(num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of num_patches. "
-                                 f"Got type: {type(num_patches)}")
-
             expected_h = expected_w = self.config.vision_config.image_size
+
             return Idefics3ImagePixelInputs(
                 type="pixel_values",
-                pixel_values=flatten_bn(pixel_values, concat=True),
-                pixel_attention_mask=flatten_bn(pixel_attention_mask,
-                                                concat=True),
-                num_patches=flatten_bn(num_patches, concat=True),
+                pixel_values=pixel_values,
+                pixel_attention_mask=pixel_attention_mask,
+                num_patches=num_patches,
                 resolve_bindings={
                     "h": expected_h,
                     "w": expected_w
```

vllm/model_executor/models/keye.py

Lines changed: 7 additions & 52 deletions
```diff
@@ -30,7 +30,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems, VideoItem)
@@ -42,7 +42,6 @@
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -100,8 +99,7 @@ def smart_resize(
 class KeyeImagePixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -110,7 +108,7 @@ class KeyeImagePixelInputs(TensorSchema):
     type: Literal["pixel_values"]
     pixel_values: Annotated[
         torch.Tensor,
-        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


@@ -134,8 +132,7 @@ class KeyeImageEmbeddingInputs(TensorSchema):
 class KeyeVideoPixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -144,7 +141,7 @@ class KeyeVideoPixelInputs(TensorSchema):
     type: Literal["pixel_values_videos"]
     pixel_values_videos: Annotated[
         torch.Tensor,
-        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]


@@ -1258,6 +1255,8 @@ def _get_mm_fields_config(


 class BaseKeyeModule(nn.Module):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -1524,28 +1523,6 @@ def _build_projector(self,
                          prefix: str = "") -> nn.Module:
         return Projector(text_config, vision_config, quant_config, prefix)

-    def _validate_and_reshape_mm_tensor(
-            self, mm_input: NestedTensors,
-            name: str) -> Union[torch.Tensor, list[torch.Tensor]]:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim == 5:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        elif is_list_of(mm_input, torch.Tensor):
-            if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
-                                                          for p in mm_input):
-                return mm_input
-        return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[KeyeImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -1556,23 +1533,13 @@ def _parse_and_validate_image_input(
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return KeyeImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
                 image_grid_thw=image_grid_thw,
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return KeyeImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1589,25 +1556,13 @@ def _parse_and_validate_video_input(
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos,
-                "video pixel values",
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return KeyeVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
                 video_grid_thw=video_grid_thw,
             )

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return KeyeVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,
```
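
The docstring and `TensorShape` changes encode the same shift: there is no longer a separate batch axis, only a flattened `bnp` axis whose length may vary. A toy stand-in for the check (not vLLM's `TensorSchema` implementation):

```python
import torch

def check_bnp(pixel_values: torch.Tensor, patch_size: int = 14) -> None:
    # Toy analogue of TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"}):
    # the leading bnp axis is unconstrained, the trailing dims are fixed.
    if pixel_values.ndim != 4:
        raise ValueError(f"expected 4D (bnp, 3, ps, ps), got {pixel_values.ndim}D")
    if tuple(pixel_values.shape[1:]) != (3, patch_size, patch_size):
        raise ValueError(f"bad trailing dims: {tuple(pixel_values.shape)}")

check_bnp(torch.randn(37, 3, 14, 14))  # any bnp length passes
```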

vllm/model_executor/models/keye_vl1_5.py

Lines changed: 5 additions & 51 deletions
```diff
@@ -18,7 +18,7 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalFieldConfig,
                                     MultiModalKwargsItems, VideoItem)
@@ -100,8 +100,7 @@ def get_num_patches(grid_thw: torch.Tensor,
 class KeyeVL1_5ImagePixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -111,7 +110,7 @@ class KeyeVL1_5ImagePixelInputs(TensorSchema):

     pixel_values: Annotated[
         torch.Tensor,
-        TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]

     image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]

@@ -137,8 +136,7 @@ class KeyeVL1_5ImageEmbeddingInputs(TensorSchema):
 class KeyeVL1_5VideoPixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -147,7 +145,7 @@ class KeyeVL1_5VideoPixelInputs(TensorSchema):
     type: Literal["pixel_values_videos"]
     pixel_values_videos: Annotated[
         torch.Tensor,
-        TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]

     num_frames: torch.Tensor
@@ -483,24 +481,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.merge_size = config.vision_config.spatial_merge_size
         super().__init__(vllm_config=vllm_config, prefix=prefix)

-    def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors,
-                                        expected_dim: int, name: str):
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == expected_dim:
-                return mm_input
-            elif mm_input.ndim == expected_dim + 1:
-                return mm_input.reshape(-1, *mm_input.shape[2:])
-            else:
-                raise ValueError(
-                    f"{name} should be {expected_dim}D or "
-                    f"batched {expected_dim}D tensor."
-                    f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})")
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -511,23 +491,13 @@ def _parse_and_validate_image_input(
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, expected_dim=4, name="image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, expected_dim=2, name="image grid_thw")
-
             return KeyeVL1_5ImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
                 image_grid_thw=image_grid_thw,
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, expected_dim=2, name="image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, expected_dim=2, name="image grid_thw")
-
             return KeyeVL1_5ImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -545,29 +515,13 @@ def _parse_and_validate_video_input(
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos,
-                expected_dim=4,
-                name="video pixel values",
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, expected_dim=2, name="video grid_thw")
-
-            num_frames = self._validate_and_reshape_mm_tensor(
-                num_frames, expected_dim=1, name="video num frames")
-
             return KeyeVL1_5VideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
                 video_grid_thw=video_grid_thw,
                 num_frames=num_frames)

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, expected_dim=2, name="video embeds")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, expected_dim=2, name="video grid_thw")
-
             return KeyeVL1_5VideoEmbeddingInputs(type="video_embeds",
                                                  video_embeds=video_embeds,
                                                  video_grid_thw=video_grid_thw,
```
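
Taken together, the deleted `_validate_and_reshape_mm_tensor` branches are now redundant: the processor delivers the flattened form the helper used to produce. A toy before/after comparison (shapes are illustrative only, not from vLLM):

```python
import torch

# Old path: a batched (b, np, 3, ps, ps) kwarg that the deleted helper
# reshaped to (b*np, 3, ps, ps).
batched = torch.randn(2, 5, 3, 14, 14)
old_result = batched.reshape(-1, *batched.shape[2:])

# New path: with merge_by_field_config, the model already receives the
# flattened tensor from the processor.
pre_merged = torch.randn(10, 3, 14, 14)

assert old_result.shape == pre_merged.shape == (10, 3, 14, 14)
```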
