
Commit c0d8f16

[Model] SiglipVisionModel ported from transformers (#6942)
Co-authored-by: Roger Wang <[email protected]>
1 parent cc08fc7 · commit c0d8f16


3 files changed: +650 −53 lines changed


examples/offline_inference_vision_language.py

Lines changed: 2 additions & 1 deletion
@@ -65,7 +65,8 @@ def run_phi3v(question):
 # PaliGemma
 def run_paligemma(question):
 
-    prompt = question
+    # PaliGemma has special prompt format for VQA
+    prompt = "caption en"
    llm = LLM(model="google/paligemma-3b-mix-224")
 
     return llm, prompt
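For context, a minimal sketch of how the updated helper is typically driven offline. The generate call follows the multi-modal input format used by this example file; the image path, question argument, and sampling settings are illustrative assumptions:

from PIL import Image

from vllm import SamplingParams

# Hypothetical driver for run_paligemma() as defined above. The "mix"
# checkpoints expect task-prefix prompts such as "caption en" rather
# than a free-form question, which is what this change encodes.
llm, prompt = run_paligemma(question=None)  # question is now unused

image = Image.open("example.jpg")  # illustrative path; any RGB image
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)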

vllm/model_executor/models/paligemma.py

Lines changed: 27 additions & 52 deletions
@@ -1,9 +1,8 @@
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
 
 import torch
-from PIL import Image
 from torch import nn
-from transformers import PaliGemmaConfig, SiglipVisionConfig, SiglipVisionModel
+from transformers import PaliGemmaConfig
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -18,9 +17,11 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import cached_get_tokenizer
-from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
+from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .interfaces import SupportsVision
+from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
+                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
 from .utils import merge_vision_embeddings
 
 logger = init_logger(__name__)
@@ -32,55 +33,22 @@
 
 def get_max_paligemma_image_tokens(ctx: InputContext):
     hf_config = ctx.get_hf_config(PaliGemmaConfig)
-    text_config = hf_config.text_config
-
-    return text_config.num_image_tokens
-
-
-def dummy_seq_data_for_paligemma(
-    hf_config: PaliGemmaConfig,
-    seq_len: int,
-    *,
-    image_token_id: int,
-    image_feature_size_override: Optional[int] = None,
-):
-    if image_feature_size_override is None:
-        image_feature_size = hf_config.text_config.num_image_tokens
-    else:
-        image_feature_size = image_feature_size_override
-
-    token_ids = [image_token_id] * image_feature_size
-    token_ids += [0] * (seq_len - image_feature_size)
-    return SequenceData(token_ids)
-
-
-def dummy_image_for_paligemma(
-    hf_config: SiglipVisionConfig,
-    *,
-    image_width_override: Optional[int] = None,
-    image_height_override: Optional[int] = None,
-):
-    width = height = hf_config.image_size
-    if image_width_override is not None:
-        width = image_width_override
-    if image_height_override is not None:
-        height = image_height_override
+    vision_config = hf_config.vision_config
 
-    image = Image.new("RGB", (width, height), color=0)
-    return {"image": image}
+    return get_max_siglip_image_tokens(vision_config)
 
 
 def dummy_data_for_paligemma(ctx: InputContext, seq_len: int):
     hf_config = ctx.get_hf_config(PaliGemmaConfig)
     vision_config = hf_config.vision_config
 
-    seq_data = dummy_seq_data_for_paligemma(
-        hf_config,
+    seq_data = dummy_seq_data_for_siglip(
+        vision_config,
         seq_len,
         image_token_id=hf_config.image_token_index,
     )
 
-    mm_data = dummy_image_for_paligemma(vision_config)
+    mm_data = dummy_image_for_siglip(vision_config)
     return seq_data, mm_data
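The deleted helpers were not dropped; they were generalized into the new .siglip module (the third changed file, not shown in this view, at vllm/model_executor/models/siglip.py per the relative import above). Reconstructed from the call sites and the removed PaliGemma-local versions, the shared helpers plausibly look like the sketch below; the exact signatures in the real module may differ:

from typing import Optional

from PIL import Image
from transformers import SiglipVisionConfig

from vllm.sequence import SequenceData


def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
    # One token per image patch: (image_size / patch_size) ** 2.
    # For PaliGemma-224 this matches the old text_config.num_image_tokens
    # value of 256 = (224 // 14) ** 2.
    return (hf_config.image_size // hf_config.patch_size)**2


def dummy_seq_data_for_siglip(
    hf_config: SiglipVisionConfig,
    seq_len: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    # Fill the front of the sequence with image placeholder tokens,
    # then pad the rest with zeros (same logic as the removed helper).
    if image_feature_size_override is None:
        image_feature_size = get_max_siglip_image_tokens(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size
    token_ids += [0] * (seq_len - image_feature_size)
    return SequenceData(token_ids)


def dummy_image_for_siglip(
    hf_config: SiglipVisionConfig,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    # A black RGB image at the model's native resolution, unless overridden.
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image}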

@@ -208,30 +176,37 @@ def _parse_and_validate_image_input(
             data=self._validate_pixel_values(pixel_values),
         )
 
-    def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
-                                  pixel_values: torch.Tensor) -> torch.Tensor:
+    def _image_pixels_to_features(
+        self,
+        vision_tower: SiglipVisionModel,
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
 
         target_dtype = vision_tower.get_input_embeddings().weight.dtype
-        image_outputs = vision_tower(pixel_values.to(dtype=target_dtype),
-                                     output_hidden_states=True)
-
-        selected_image_features = image_outputs.last_hidden_state
+        image_features = vision_tower(pixel_values.to(dtype=target_dtype))
 
-        return selected_image_features
+        return image_features
 
     def _process_image_pixels(
-            self, inputs: PaliGemmaImagePixelInputs) -> torch.Tensor:
+        self,
+        inputs: PaliGemmaImagePixelInputs,
+    ) -> torch.Tensor:
         assert self.vision_tower is not None
 
         pixel_values = inputs["data"]
 
-        return self._image_pixels_to_features(self.vision_tower, pixel_values)
+        return self._image_pixels_to_features(
+            self.vision_tower,
+            pixel_values,
+        )
 
     def _process_image_input(
-            self, image_input: PaliGemmaImageInputs) -> torch.Tensor:
+        self,
+        image_input: PaliGemmaImageInputs,
+    ) -> torch.Tensor:
 
         assert self.vision_tower is not None
-        image_features = self._process_image_pixels(image_input)
+        image_features = self._process_image_pixels(image_input, )
 
         return self.multi_modal_projector(image_features)
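The last hunk tracks an interface difference in the ported tower: transformers.SiglipVisionModel returns a ModelOutput whose last_hidden_state must be selected, whereas the in-tree model hands back the hidden states as a plain tensor. A minimal stand-in illustrating the new call-site contract; the shapes and hidden size here are illustrative, loosely following SigLIP-style configs, not the real module's forward:

import torch
from torch import nn


class TensorReturningTower(nn.Module):
    """Stand-in for the in-tree SiglipVisionModel (illustrative only)."""

    def __init__(self, num_patches: int = 256, hidden_size: int = 1152):
        super().__init__()
        self.num_patches = num_patches
        self.hidden_size = hidden_size

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch = pixel_values.shape[0]
        # Real model: patch embedding + encoder; here, only the output
        # shape matters: callers get a tensor, not a ModelOutput, so no
        # output_hidden_states=True or .last_hidden_state selection.
        return torch.zeros(batch, self.num_patches, self.hidden_size)


tower = TensorReturningTower()
features = tower(torch.zeros(2, 3, 224, 224))
assert features.shape == (2, 256, 1152)  # (batch, num_patches, hidden)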
