Skip to content

Commit b50602d

Browse files
authored
[Model][Gemma3] Cast image pixel values already on CPU (#18732)
Signed-off-by: Lukas Geiger <[email protected]>
1 parent 1f1b1bc commit b50602d

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

vllm/model_executor/models/gemma3_mm.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -263,6 +263,11 @@ def _call_hf_processor(
263263
mm_data,
264264
mm_kwargs,
265265
)
266+
if "pixel_values" in processed_outputs:
267+
# Cast pixel values to model dtype already here,
268+
# so we need to transfer less data to the GPU
269+
processed_outputs["pixel_values"] = processed_outputs[
270+
"pixel_values"].to(self.info.ctx.model_config.dtype)
266271

267272
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
268273
if (images := mm_data.get("images")) is not None:
@@ -543,9 +548,7 @@ def _image_pixels_to_features(
543548
vision_tower: SiglipVisionModel,
544549
pixel_values: torch.Tensor,
545550
) -> torch.Tensor:
546-
target_dtype = vision_tower.get_input_embeddings().weight.dtype
547-
image_features = vision_tower(pixel_values.to(dtype=target_dtype))
548-
return image_features
551+
return vision_tower(pixel_values)
549552

550553
def _process_image_input(
551554
self,

0 commit comments

Comments
 (0)