Commit e4f6360

Merge branch 'main' into fix-gpt-oss-with-speculative-decoding-handle-multiple-channels
2 parents e1f14dd + ab5e7d9

File tree

2 files changed (+38 −68 lines)

tests/models/multimodal/generation/test_common.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -222,8 +222,7 @@
         vllm_runner_kwargs={
             "model_impl": "transformers",
         },
-        # FIXME: Investigate mrope issue
-        marks=[large_gpu_mark(min_gb=32), pytest.mark.skip(reason="Mrope issue")],
+        marks=[large_gpu_mark(min_gb=32)],
     ),
     #### Extended model tests
     "aria": VLMTestInfo(
```

vllm/model_executor/models/transformers.py

Lines changed: 37 additions & 66 deletions
```diff
@@ -79,7 +79,6 @@
     AutoWeightsLoader,
     PPMissingLayer,
     WeightsMapper,
-    flatten_bn,
     make_empty_intermediate_tensors_factory,
     maybe_prefix,
 )
```
```diff
@@ -347,54 +346,37 @@ def _get_prompt_updates(
 
     def _get_mm_fields_config(
         self,
-        hf_inputs,
-        hf_processor_mm_kwargs,
-        num_image_patches: torch.Tensor = None,
-    ):
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
         # HF Processors always return a mask but vLLM doesn't need it
         hf_inputs.pop("attention_mask", None)
+        num_image_patches = hf_inputs.get("num_image_patches")
         mm_fields = {
             key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
             for key in hf_inputs
         }
         mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
             "image", num_image_patches
         )
+
+        # Keep these as batched, as they always have batch size as first dim
+        mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
+        mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
         mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
         return mm_fields
 
-    def _apply_hf_processor_text_mm(
+    def _get_hf_mm_data(
         self,
-        prompt_text: str,
         mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        tokenization_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], BatchFeature, bool]:
+    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
         """
-        Apply the HF processor on the prompt text and multi-modal data
-        together.
-
-        In addition, return whether prompt replacements have been applied.
+        In contrast to the base class, this method always adds
+        `return_mm_token_type_ids` to the processor data
         """
-        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
+        processor_data, passthrough_data = super()._get_hf_mm_data(mm_items)
         processor_data["return_mm_token_type_ids"] = True
-
-        processed_data = self._call_hf_processor(
-            prompt=prompt_text,
-            mm_data=processor_data,
-            mm_kwargs=hf_processor_mm_kwargs,
-            tok_kwargs=tokenization_kwargs,
-        )
-        processed_data.update(passthrough_data)
-
-        (prompt_ids,) = processed_data.pop("input_ids").tolist()
-        mm_token_type_ids = (
-            processed_data.pop("mm_token_type_ids")
-            if "mm_token_type_ids" in processed_data
-            else processed_data.pop("token_type_ids")
-        )  # for gemma3 only
-
-        return prompt_ids, processed_data, mm_token_type_ids
+        return processor_data, passthrough_data
 
     def apply(
         self,
```
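
The layout distinction behind this field config can be illustrated with plain torch (all shapes below are fabricated): fields registered via `flat_from_sizes` are concatenated along dim 0 and split back apart using `num_image_patches`, while `batched` fields already carry one row per item, which is why the `*_grid_thw` fields must stay batched:

```python
import torch

# Per-patch features for 2 images in "flat" layout, concatenated along
# dim 0: image 0 contributes 3 patches, image 1 contributes 5.
num_image_patches = torch.tensor([3, 5])
pixel_values = torch.randn(8, 1176)  # 8 = 3 + 5 patches in total

# flat_from_sizes-style recovery: split the flat tensor back into one
# tensor per image using the per-item sizes.
per_image = torch.split(pixel_values, num_image_patches.tolist(), dim=0)
assert [t.shape[0] for t in per_image] == [3, 5]

# batched-style layout: the first dim is already the number of items,
# so no splitting by patch counts is needed (hence "keep as batched").
image_grid_thw = torch.tensor([[1, 14, 14], [1, 28, 14]])  # one row per image
assert image_grid_thw.shape[0] == 2
```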
```diff
@@ -421,18 +403,28 @@ def apply(
         # into string
         prompt = hf_processor.decode(prompt)
 
-        (prompt_ids, processed_data, mm_token_type_ids) = (
-            self._apply_hf_processor_text_mm(
-                prompt_text=prompt,
-                mm_items=mm_items,
-                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-                tokenization_kwargs=tokenization_kwargs,
-            )
+        # Bypass cached processor and always apply to the full set of mm inputs
+        # NOTE: we can't just set caching=False because base class method
+        # transforms outputs to `MultiModalKwargs` which is not going to
+        # work for Transformers. We have a lot of logic tied to
+        # `mm_tokens_per_modality` below
+        prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
+            prompt_text=prompt,
+            mm_items=mm_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            tokenization_kwargs=tokenization_kwargs,
         )
 
-        # HF processor will return `mm_token_type_ids` from which
-        # we can infer mm_placeholders. Until then hardcode to make code run
-        # Below tested on Llava. Prompts and `mm_token_type_ids` are always bs=1
+        # For gemma3 we check `token_type_ids` as the key
+        token_type_key = (
+            "mm_token_type_ids"
+            if "mm_token_type_ids" in processed_data
+            else "token_type_ids"
+        )
+        mm_token_type_ids = processed_data.pop(token_type_key)
+
+        # We can infer vLLM style placeholder from token type ids, if we split
+        # it for each input `mm_data`.
         mm_positions = torch.where(mm_token_type_ids == 1)[1]
         images = mm_items.get_items("image", ImageProcessorItems)
         multimodal_config = self.info.ctx.model_config.multimodal_config
```
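
A self-contained sketch (with fabricated ids and per-item token counts) of how placeholder ranges fall out of the token-type-id trick used above; vLLM's actual `PlaceholderRange` bookkeeping is omitted:

```python
import torch

# 1 marks multimodal positions in the (batch-size-1) token sequence.
mm_token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 0, 1, 1, 0]])

# Positions of all multimodal tokens, as in the diff.
mm_positions = torch.where(mm_token_type_ids == 1)[1]

# Split the flat position list into per-item chunks given how many
# placeholder tokens each item occupies (fabricated counts here).
tokens_per_item = [3, 2]
chunks = torch.split(mm_positions, tokens_per_item)

# Each item's placeholder is then just an offset plus a length.
ranges = [{"offset": int(c[0]), "length": len(c)} for c in chunks]
assert ranges == [{"offset": 2, "length": 3}, {"offset": 6, "length": 2}]
```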
```diff
@@ -462,17 +454,12 @@ def apply(
         ]
         mm_placeholders = {"image": ranges}
 
-        num_image_patches = (
-            torch.tensor(mm_tokens_per_modality["num_image_patches"])
-            if "num_image_patches" in mm_tokens_per_modality
-            else None
+        processed_data["num_image_patches"] = torch.tensor(
+            mm_tokens_per_modality["num_image_patches"]
         )
-        processed_data["num_image_patches"] = num_image_patches
         mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
             processed_data,
-            self._get_mm_fields_config(
-                processed_data, hf_processor_mm_kwargs, num_image_patches
-            ),
+            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
         )
 
         # Use overrides if provided; fallback to data-dependent hashing.
```
```diff
@@ -531,8 +518,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.ignore_unexpected_suffixes.append(".bias")
 
         # Set correct attn and init on "meta" to delay allocating GPU tensors
-        # TODO: @raushan, use the public `model.set_attn_implementation()`
-        # method once its checks are fixed in Transformers.
         self.text_config._attn_implementation = "vllm"
         with init_on_device_without_buffers("meta"):
             self.model: PreTrainedModel = AutoModel.from_config(
```
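
Context for the surviving lines: the model is constructed on the "meta" device so no real memory is allocated until weights load. A minimal stock-PyTorch sketch of the same idea (vLLM's `init_on_device_without_buffers` helper additionally keeps buffers off the meta device):

```python
import torch
import torch.nn as nn

# Constructing under a meta-device context records parameter *shapes*
# only; no storage is allocated until weights are materialized.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)

assert layer.weight.device.type == "meta"
assert layer.weight.numel() == 4096 * 4096  # shape known, memory not allocated

# Later, materialize on a real device with uninitialized storage (a
# weight loader would overwrite the values anyway).
layer = layer.to_empty(device="cpu")
assert layer.weight.device.type == "cpu"
```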
```diff
@@ -844,17 +829,6 @@ def compute_logits(
         return logits
 
 
-def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
-    """Flatten until a list of tensors can be concatenated then do concat"""
-
-    def _can_concat(x: list[torch.Tensor]):
-        return len(set(map(lambda _x: _x.shape[1:], x))) == 1
-
-    if _can_concat(x):
-        return torch.concat(x)
-    return flatten_and_concat(flatten_bn(x))
-
-
 @MULTIMODAL_REGISTRY.register_processor(
     MultiModalProcessor,
     info=MultiModalProcessingInfo,
```
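
For reference, the deleted helper's behavior can be reproduced without the removed `flatten_bn` import; this is a hedged, standalone re-implementation with fabricated demo shapes, not the code that replaced it:

```python
import torch


def flatten_and_concat(x: list) -> torch.Tensor:
    """Flatten nested lists of tensors until all tensors share trailing
    dims, then concatenate along dim 0."""
    # Unwrap one level of list nesting (the role flatten_bn played).
    while any(isinstance(item, list) for item in x):
        x = [t for item in x for t in (item if isinstance(item, list) else [item])]
    if len({t.shape[1:] for t in x}) == 1:
        return torch.concat(x)
    # Trailing dims still differ: peel off each tensor's leading dim
    # into separate slices and try again.
    return flatten_and_concat([t for tensor in x for t in tensor.unbind(0)])


out = flatten_and_concat([torch.randn(2, 7), [torch.randn(3, 7)]])
assert out.shape == (5, 7)
```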
```diff
@@ -935,9 +909,6 @@ def get_multimodal_embeddings(self, **kwargs):
         vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
 
         if isinstance(vision_embeddings, torch.Tensor):
-            if isinstance(num_image_patches, list):
-                num_image_patches = torch.cat(num_image_patches)
-
             if vision_embeddings.ndim == 2:
                 vision_embeddings = vision_embeddings.unsqueeze(0)
 
```
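
With `num_image_patches` now registered as a batched field it already arrives as a tensor, so the deleted list-handling branch is dead code. What remains is the batch-of-one promotion, sketched here with fabricated shapes:

```python
import torch

# A fabricated single-image embedding: (num_patches, hidden_size).
vision_embeddings = torch.randn(16, 1024)

# Batched field config delivers a tensor directly; no
# isinstance(..., list) check followed by torch.cat is needed.
num_image_patches = torch.tensor([16])

# As in the diff: a 2-D result is treated as a batch of one image.
if vision_embeddings.ndim == 2:
    vision_embeddings = vision_embeddings.unsqueeze(0)

assert vision_embeddings.shape == (1, 16, 1024)
```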
