     AutoWeightsLoader,
     PPMissingLayer,
     WeightsMapper,
-    flatten_bn,
     make_empty_intermediate_tensors_factory,
     maybe_prefix,
 )
@@ -347,54 +346,37 @@ def _get_prompt_updates(
 
     def _get_mm_fields_config(
         self,
-        hf_inputs,
-        hf_processor_mm_kwargs,
-        num_image_patches: torch.Tensor = None,
-    ):
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
         # HF Processors always return a mask but vLLM doesn't need it
         hf_inputs.pop("attention_mask", None)
+        num_image_patches = hf_inputs.get("num_image_patches")
         mm_fields = {
             key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
             for key in hf_inputs
         }
         mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
             "image", num_image_patches
         )
+
+        # Keep these as batched, as they always have batch size as first dim
+        mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
+        mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
         mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
         return mm_fields
 
-    def _apply_hf_processor_text_mm(
+    def _get_hf_mm_data(
         self,
-        prompt_text: str,
         mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        tokenization_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], BatchFeature, bool]:
+    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
         """
-        Apply the HF processor on the prompt text and multi-modal data
-        together.
-
-        In addition, return whether prompt replacements have been applied.
+        In contrast to the base class, this method always adds
+        `return_mm_token_type_ids` to the processor data
         """
-        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
+        processor_data, passthrough_data = super()._get_hf_mm_data(mm_items)
         processor_data["return_mm_token_type_ids"] = True
-
-        processed_data = self._call_hf_processor(
-            prompt=prompt_text,
-            mm_data=processor_data,
-            mm_kwargs=hf_processor_mm_kwargs,
-            tok_kwargs=tokenization_kwargs,
-        )
-        processed_data.update(passthrough_data)
-
-        (prompt_ids,) = processed_data.pop("input_ids").tolist()
-        mm_token_type_ids = (
-            processed_data.pop("mm_token_type_ids")
-            if "mm_token_type_ids" in processed_data
-            else processed_data.pop("token_type_ids")
-        )  # for gemma3 only
-
-        return prompt_ids, processed_data, mm_token_type_ids
+        return processor_data, passthrough_data
 
     def apply(
         self,
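
[Aside, not part of the patch] The hunk above hinges on the difference between `flat_from_sizes` and `batched` fields: a flat field is a single patch-major tensor that must be split back into per-image chunks using `num_image_patches`, while a batched field already has one entry per image in its leading dimension. The following standalone sketch illustrates that splitting with plain PyTorch; the helper name `split_outputs_per_image` and the example shapes are made up for illustration and are not the vLLM API.

```python
import torch

# Toy illustration of the flat-vs-batched distinction, not the vLLM API:
# tensors whose leading dim is the total patch count get split by
# num_image_patches ("flat_from_sizes"); tensors whose leading dim is the
# number of images are indexed per image ("batched").
def split_outputs_per_image(
    outputs: dict[str, torch.Tensor],
    num_image_patches: torch.Tensor,
) -> list[dict[str, torch.Tensor]]:
    sizes = num_image_patches.tolist()
    per_image: list[dict[str, torch.Tensor]] = [{} for _ in sizes]
    for key, value in outputs.items():
        if value.shape[0] == sum(sizes):  # patch-major, "flat" field
            for i, chunk in enumerate(torch.split(value, sizes, dim=0)):
                per_image[i][key] = chunk
        else:  # image-major, "batched" field
            for i in range(len(sizes)):
                per_image[i][key] = value[i]
    return per_image

# Example: 3 images contributing 2, 4 and 1 patches each.
outputs = {
    "pixel_values": torch.randn(7, 3, 14, 14),     # 2 + 4 + 1 = 7 patches
    "num_image_patches": torch.tensor([2, 4, 1]),  # one entry per image
}
items = split_outputs_per_image(outputs, outputs["num_image_patches"])
assert [item["pixel_values"].shape[0] for item in items] == [2, 4, 1]
```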
@@ -421,18 +403,28 @@ def apply(
             # into string
             prompt = hf_processor.decode(prompt)
 
-        (prompt_ids, processed_data, mm_token_type_ids) = (
-            self._apply_hf_processor_text_mm(
-                prompt_text=prompt,
-                mm_items=mm_items,
-                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-                tokenization_kwargs=tokenization_kwargs,
-            )
+        # Bypass cached processor and always apply to the full set of mm inputs
+        # NOTE: we can't just set caching=False because base class method
+        # transforms outputs to `MultiModalKwargs` which is not going to
+        # work for Transformers. We have a lot of logic tied to
+        # `mm_tokens_per_modality` below
+        prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
+            prompt_text=prompt,
+            mm_items=mm_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            tokenization_kwargs=tokenization_kwargs,
         )
 
-        # HF processor will return `mm_token_type_ids` from which
-        # we can infer mm_placeholders. Until then hardcode to make code run
-        # Below tested on Llava. Prompts and `mm_token_type_ids` are always bs=1
+        # For gemma3 we check `token_type_ids` as the key
+        token_type_key = (
+            "mm_token_type_ids"
+            if "mm_token_type_ids" in processed_data
+            else "token_type_ids"
+        )
+        mm_token_type_ids = processed_data.pop(token_type_key)
+
+        # We can infer vLLM style placeholder from token type ids, if we split
+        # it for each input `mm_data`.
         mm_positions = torch.where(mm_token_type_ids == 1)[1]
         images = mm_items.get_items("image", ImageProcessorItems)
         multimodal_config = self.info.ctx.model_config.multimodal_config
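
[Aside, not part of the patch] The placeholder inference that this hunk sets up works by finding the positions where the multimodal token-type mask equals 1 and slicing them per image according to each image's token count. Below is a toy, torch-only sketch of that idea; the function name `placeholder_ranges_from_mask` and the numbers are hypothetical, and the real code operates on the HF processor outputs rather than on a hand-built mask.

```python
import torch

# Illustrative only: derive per-image placeholder ranges (offset, length)
# from a 1D multimodal token mask, assuming each image occupies one
# contiguous run of mask == 1 tokens.
def placeholder_ranges_from_mask(mm_mask: torch.Tensor, tokens_per_image: list[int]):
    positions = torch.where(mm_mask == 1)[0]
    ranges, start = [], 0
    for n in tokens_per_image:
        chunk = positions[start:start + n]
        ranges.append({"offset": int(chunk[0]), "length": n})
        start += n
    return ranges

# Example: a 10-token prompt with two images of 3 and 2 placeholder tokens.
mask = torch.tensor([0, 1, 1, 1, 0, 0, 1, 1, 0, 0])
print(placeholder_ranges_from_mask(mask, [3, 2]))
# [{'offset': 1, 'length': 3}, {'offset': 6, 'length': 2}]
```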
@@ -462,17 +454,12 @@ def apply(
         ]
         mm_placeholders = {"image": ranges}
 
-        num_image_patches = (
-            torch.tensor(mm_tokens_per_modality["num_image_patches"])
-            if "num_image_patches" in mm_tokens_per_modality
-            else None
+        processed_data["num_image_patches"] = torch.tensor(
+            mm_tokens_per_modality["num_image_patches"]
         )
-        processed_data["num_image_patches"] = num_image_patches
         mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
             processed_data,
-            self._get_mm_fields_config(
-                processed_data, hf_processor_mm_kwargs, num_image_patches
-            ),
+            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
         )
 
         # Use overrides if provided; fallback to data-dependent hashing.
@@ -531,8 +518,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.ignore_unexpected_suffixes.append(".bias")
 
         # Set correct attn and init on "meta" to delay allocating GPU tensors
-        # TODO: @raushan, use the public `model.set_attn_implementation()`
-        # method once its checks are fixed in Transformers.
         self.text_config._attn_implementation = "vllm"
         with init_on_device_without_buffers("meta"):
             self.model: PreTrainedModel = AutoModel.from_config(
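
[Aside, not part of the patch] Constructing the backbone under a "meta" device context, as above, records parameter shapes and dtypes without allocating memory; real storage is only created later when weights are loaded. A minimal torch-only sketch of the same idea follows; it does not reproduce whatever special handling of buffers the vLLM helper `init_on_device_without_buffers` provides.

```python
import torch
import torch.nn as nn

# Sketch only: parameters created under the "meta" device carry shape and
# dtype metadata but no storage, so construction is cheap and allocation
# can be deferred until weights are actually loaded.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)

print(layer.weight.device)  # meta
print(layer.weight.shape)   # torch.Size([4096, 4096]), but no memory backing it

# Later, allocate uninitialized storage on a real device and load weights into it.
layer = layer.to_empty(device="cpu")
state_dict = {"weight": torch.zeros(4096, 4096), "bias": torch.zeros(4096)}
layer.load_state_dict(state_dict)
```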
@@ -844,17 +829,6 @@ def compute_logits(
         return logits
 
 
-def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
-    """Flatten until a list of tensors can be concatenated then do concat"""
-
-    def _can_concat(x: list[torch.Tensor]):
-        return len(set(map(lambda _x: _x.shape[1:], x))) == 1
-
-    if _can_concat(x):
-        return torch.concat(x)
-    return flatten_and_concat(flatten_bn(x))
-
-
 @MULTIMODAL_REGISTRY.register_processor(
     MultiModalProcessor,
     info=MultiModalProcessingInfo,
@@ -935,9 +909,6 @@ def get_multimodal_embeddings(self, **kwargs):
         vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
 
         if isinstance(vision_embeddings, torch.Tensor):
-            if isinstance(num_image_patches, list):
-                num_image_patches = torch.cat(num_image_patches)
-
             if vision_embeddings.ndim == 2:
                 vision_embeddings = vision_embeddings.unsqueeze(0)
 