@@ -1,9 +1,8 @@
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
 
 import torch
-from PIL import Image
 from torch import nn
-from transformers import PaliGemmaConfig, SiglipVisionConfig, SiglipVisionModel
+from transformers import PaliGemmaConfig
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -18,9 +17,11 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import cached_get_tokenizer
-from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
+from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .interfaces import SupportsVision
+from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
+                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
 from .utils import merge_vision_embeddings
 
 logger = init_logger(__name__)
@@ -32,55 +33,22 @@
 
 def get_max_paligemma_image_tokens(ctx: InputContext):
     hf_config = ctx.get_hf_config(PaliGemmaConfig)
-    text_config = hf_config.text_config
-
-    return text_config.num_image_tokens
-
-
-def dummy_seq_data_for_paligemma(
-    hf_config: PaliGemmaConfig,
-    seq_len: int,
-    *,
-    image_token_id: int,
-    image_feature_size_override: Optional[int] = None,
-):
-    if image_feature_size_override is None:
-        image_feature_size = hf_config.text_config.num_image_tokens
-    else:
-        image_feature_size = image_feature_size_override
-
-    token_ids = [image_token_id] * image_feature_size
-    token_ids += [0] * (seq_len - image_feature_size)
-    return SequenceData(token_ids)
-
-
-def dummy_image_for_paligemma(
-    hf_config: SiglipVisionConfig,
-    *,
-    image_width_override: Optional[int] = None,
-    image_height_override: Optional[int] = None,
-):
-    width = height = hf_config.image_size
-    if image_width_override is not None:
-        width = image_width_override
-    if image_height_override is not None:
-        height = image_height_override
+    vision_config = hf_config.vision_config
 
-    image = Image.new("RGB", (width, height), color=0)
-    return {"image": image}
+    return get_max_siglip_image_tokens(vision_config)
 
 
 def dummy_data_for_paligemma(ctx: InputContext, seq_len: int):
     hf_config = ctx.get_hf_config(PaliGemmaConfig)
     vision_config = hf_config.vision_config
 
-    seq_data = dummy_seq_data_for_paligemma(
-        hf_config,
+    seq_data = dummy_seq_data_for_siglip(
+        vision_config,
         seq_len,
         image_token_id=hf_config.image_token_index,
     )
 
-    mm_data = dummy_image_for_paligemma(vision_config)
+    mm_data = dummy_image_for_siglip(vision_config)
     return seq_data, mm_data
 
 
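The helpers deleted above move into the shared `.siglip` module so other SigLIP-backed models can reuse them. This diff only shows the call sites; as a rough reconstruction from the removed PaliGemma-specific code, the shared versions presumably look something like the sketch below. The `(image_size // patch_size) ** 2` token count inside `get_max_siglip_image_tokens` is an assumption, not shown in this diff:

```python
from typing import Optional

from PIL import Image
from transformers import SiglipVisionConfig

from vllm.sequence import SequenceData


def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
    # Assumed: one token per patch, with no class token in SigLIP.
    return (hf_config.image_size // hf_config.patch_size)**2


def dummy_seq_data_for_siglip(
    hf_config: SiglipVisionConfig,
    seq_len: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    # Mirrors the deleted dummy_seq_data_for_paligemma, but derives the
    # placeholder count from the vision config instead of the text config.
    if image_feature_size_override is None:
        image_feature_size = get_max_siglip_image_tokens(hf_config)
    else:
        image_feature_size = image_feature_size_override

    # Front-load the image placeholder tokens, then pad with zeros.
    token_ids = [image_token_id] * image_feature_size
    token_ids += [0] * (seq_len - image_feature_size)
    return SequenceData(token_ids)


def dummy_image_for_siglip(
    hf_config: SiglipVisionConfig,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    # Same logic as the deleted dummy_image_for_paligemma: a black square
    # at the tower's native resolution unless overridden.
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image}
```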
@@ -208,30 +176,37 @@ def _parse_and_validate_image_input(
             data=self._validate_pixel_values(pixel_values),
         )
 
-    def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
-                                  pixel_values: torch.Tensor) -> torch.Tensor:
+    def _image_pixels_to_features(
+        self,
+        vision_tower: SiglipVisionModel,
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
 
         target_dtype = vision_tower.get_input_embeddings().weight.dtype
-        image_outputs = vision_tower(pixel_values.to(dtype=target_dtype),
-                                     output_hidden_states=True)
-
-        selected_image_features = image_outputs.last_hidden_state
+        image_features = vision_tower(pixel_values.to(dtype=target_dtype))
 
-        return selected_image_features
+        return image_features
 
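Besides the signature reflow, `_image_pixels_to_features` no longer passes `output_hidden_states=True` only to read `image_outputs.last_hidden_state` back; the vendored `SiglipVisionModel` from `.siglip` is evidently expected to return the final patch features directly. The dtype cast remains necessary because image processors typically emit float32 pixel values while the tower weights may be loaded in half precision. A minimal illustration of that cast (shapes and dtypes are illustrative only):

```python
import torch

# e.g. a tower loaded in bfloat16 while the image processor emits float32
weights_dtype = torch.bfloat16
pixel_values = torch.rand(1, 3, 224, 224)  # float32 by default

# Casting before the forward pass avoids a dtype-mismatch error inside
# the tower's first conv/linear layer.
pixel_values = pixel_values.to(dtype=weights_dtype)
assert pixel_values.dtype == weights_dtype
```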
     def _process_image_pixels(
-            self, inputs: PaliGemmaImagePixelInputs) -> torch.Tensor:
+        self,
+        inputs: PaliGemmaImagePixelInputs,
+    ) -> torch.Tensor:
         assert self.vision_tower is not None
 
         pixel_values = inputs["data"]
 
-        return self._image_pixels_to_features(self.vision_tower, pixel_values)
+        return self._image_pixels_to_features(
+            self.vision_tower,
+            pixel_values,
+        )
 
     def _process_image_input(
-            self, image_input: PaliGemmaImageInputs) -> torch.Tensor:
+        self,
+        image_input: PaliGemmaImageInputs,
+    ) -> torch.Tensor:
 
         assert self.vision_tower is not None
         image_features = self._process_image_pixels(image_input)
 
         return self.multi_modal_projector(image_features)
 
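Downstream, `_process_image_input` feeds the tower output through `multi_modal_projector`, and the result is later merged into the text embeddings via the imported `merge_vision_embeddings`. A sketch of the projection step, assuming a single linear layer from vision width to text width (the diff only shows that the projector is applied last, so the layer type and sizes here are hypothetical):

```python
import torch
from torch import nn

# Illustrative sizes only: a SigLIP-so400m-like width projected to a
# Gemma-like hidden size. PaliGemma's real projector may differ.
vision_hidden, text_hidden, num_patches = 1152, 2048, 256

projector = nn.Linear(vision_hidden, text_hidden)  # hypothetical stand-in

image_features = torch.randn(1, num_patches, vision_hidden)
image_embeds = projector(image_features)

# The projected embeddings are shaped to drop into the text embedding
# sequence at the image placeholder positions.
assert image_embeds.shape == (1, num_patches, text_hidden)
```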