@@ -178,42 +178,13 @@ def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutpu
         assert self.processor is not None, "Processor is not initialized"
         processed = self.processor(image_files)
 
-        # Handle nested structure (with image splitting)
+        # Dispatch to the appropriate handler based on structure.
+        # The ColModernVBERT processor divides the original image into
+        # subimages and processes them separately.
         if isinstance(processed[0], list):
-            # processed = [[img1_patches], [img2_patches], ...]
-            # Need shape: (batch_size, max_patches, C, H, W)
-
-            patch_counts = [len(patches) for patches in processed]
-            max_patches = max(patch_counts)
-
-            # Get dimensions from first patch
-            C, H, W = processed[0][0].shape
-
-            # Create padded array
-            batch_size = len(processed)
-            encoded = np.zeros((batch_size, max_patches, C, H, W), dtype=processed[0][0].dtype)
-
-            # Create attention mask (1 for real patches, 0 for padding)
-            attention_mask = np.zeros((batch_size, max_patches), dtype=np.int64)
-
-            # Fill in patches and attention mask
-            for i, patches in enumerate(processed):
-                for j, patch in enumerate(patches):
-                    encoded[i, j] = patch
-                    attention_mask[i, j] = 1
-
-            # Track actual patch counts for later use
-            metadata = {"patch_counts": patch_counts}
+            encoded, attention_mask, metadata = self._process_nested_patches(processed)
         else:
-            # Flat structure (no splitting) - still need batch dimension
-            # Shape: (batch_size, 1, C, H, W)
-            encoded = np.array(processed)
-            if len(encoded.shape) == 4:  # (batch_size, C, H, W)
-                encoded = encoded[:, np.newaxis, ...]  # Add num_patches=1 dimension
-
-            # All patches are real (no padding)
-            attention_mask = np.ones((len(images), encoded.shape[1]), dtype=np.int64)
-            metadata = {"patch_counts": [encoded.shape[1]] * len(images)}
+            encoded, attention_mask, metadata = self._process_flat_images(processed, len(images))
 
         onnx_input = {"pixel_values": encoded, "attention_mask": attention_mask}
         onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs)
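
A minimal sketch of what the dispatch distinguishes, assuming a processor that returns numpy arrays; the shapes and the `nested`/`flat` names are illustrative, not from the PR:

import numpy as np

# Two images: the first split into 3 subimages, the second into 2.
nested = [
    [np.zeros((3, 224, 224), dtype=np.float32) for _ in range(3)],
    [np.zeros((3, 224, 224), dtype=np.float32) for _ in range(2)],
]
# No splitting: one array per image.
flat = [np.zeros((3, 224, 224), dtype=np.float32) for _ in range(2)]

assert isinstance(nested[0], list)    # taken by the nested-patch branch
assert not isinstance(flat[0], list)  # taken by the flat-image branch
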
@@ -225,6 +196,105 @@ def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutpu
             metadata=metadata,
         )
 
+    def _process_nested_patches(
+        self, processed: list[list[NumpyArray]]
+    ) -> tuple[NumpyArray, NumpyArray, dict[str, Any]]:
+        """
+        Process nested image patches (from ImageSplitter).
+
+        Args:
+            processed: List of patch lists, one per image: [[img1_patches], [img2_patches], ...]
+
+        Returns:
+            tuple: (encoded array, attention_mask, metadata)
+                - encoded: (batch_size, max_patches, C, H, W)
+                - attention_mask: (batch_size, max_patches) with 1 for real patches, 0 for padding
+                - metadata: Dict with 'patch_counts' key
+        """
+        patch_counts = [len(patches) for patches in processed]
+        max_patches = max(patch_counts)
+
+        # Get dimensions from first patch
+        C, H, W = processed[0][0].shape
+        batch_size = len(processed)
+
+        # Create padded array
+        encoded = np.zeros((batch_size, max_patches, C, H, W), dtype=processed[0][0].dtype)
+
+        # Create attention mask (1 for real patches, 0 for padding)
+        attention_mask = np.zeros((batch_size, max_patches), dtype=np.int64)
+
+        # Fill in patches and attention mask
+        for i, patches in enumerate(processed):
+            for j, patch in enumerate(patches):
+                encoded[i, j] = patch
+                attention_mask[i, j] = 1
+
+        metadata = {"patch_counts": patch_counts}
+        return encoded, attention_mask, metadata
+
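
To see the padding behavior in isolation, here is a runnable sketch of the same logic outside the class; the `pad_nested_patches` name and the toy patch shape are hypothetical:

import numpy as np

def pad_nested_patches(processed):
    # Mirror of the method above: right-pad each image's patch list to the batch maximum.
    patch_counts = [len(p) for p in processed]
    max_patches = max(patch_counts)
    C, H, W = processed[0][0].shape
    encoded = np.zeros((len(processed), max_patches, C, H, W), dtype=processed[0][0].dtype)
    mask = np.zeros((len(processed), max_patches), dtype=np.int64)
    for i, patches in enumerate(processed):
        for j, patch in enumerate(patches):
            encoded[i, j] = patch
            mask[i, j] = 1
    return encoded, mask, {"patch_counts": patch_counts}

batch = [
    [np.ones((3, 14, 14), dtype=np.float32)] * 3,  # image with 3 patches
    [np.ones((3, 14, 14), dtype=np.float32)] * 1,  # image with 1 patch
]
encoded, mask, meta = pad_nested_patches(batch)
print(encoded.shape)  # (2, 3, 3, 14, 14)
print(mask.tolist())  # [[1, 1, 1], [1, 0, 0]]
print(meta)           # {'patch_counts': [3, 1]}
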
+    def _process_flat_images(
+        self, processed: list[NumpyArray], num_images: int
+    ) -> tuple[NumpyArray, NumpyArray, dict[str, Any]]:
+        """
+        Process flat image arrays (from standard processors like SiglipImageProcessor).
+
+        For models expecting 5D input (Idefics3-based), adds a patch dimension.
+        For models expecting 4D input, keeps the original shape.
+
+        Args:
+            processed: List of image arrays
+            num_images: Number of images being processed
+
+        Returns:
+            tuple: (encoded array, attention_mask, metadata)
+                - encoded: (batch_size, C, H, W) for 4D models or (batch_size, 1, C, H, W) for 5D models
+                - attention_mask: (batch_size, 1) with all ones
+                - metadata: Dict with 'patch_counts' key
+        """
+        encoded = np.array(processed)
+
+        # Check whether the model needs a patch dimension based on its ONNX signature
+        if len(encoded.shape) == 4 and self._needs_patch_dimension():
+            # Add patch dimension for Idefics3-based models: (batch, 1, C, H, W)
+            encoded = encoded[:, np.newaxis, ...]
+
+        # Determine attention mask shape based on the final tensor shape
+        if len(encoded.shape) == 5:
+            # 5D tensor: attention_mask shape is (batch, num_patches)
+            attention_mask = np.ones((num_images, encoded.shape[1]), dtype=np.int64)
+            metadata = {"patch_counts": [encoded.shape[1]] * num_images}
+        else:
+            # 4D tensor: attention_mask shape is (batch, 1)
+            attention_mask = np.ones((num_images, 1), dtype=np.int64)
+            metadata = {"patch_counts": [1] * num_images}
+
+        return encoded, attention_mask, metadata
+
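
A short sketch of the two shape outcomes of the flat path, with an assumed 224x224 input size:

import numpy as np

processed = [np.zeros((3, 224, 224), dtype=np.float32) for _ in range(2)]
encoded = np.array(processed)             # (2, 3, 224, 224): fine for 4D models

# Idefics3-style (5D) models get a singleton patch dimension instead:
encoded_5d = encoded[:, np.newaxis, ...]  # (2, 1, 3, 224, 224)

# Either way the mask ends up (batch, 1) here, since nothing was split.
mask = np.ones((2, encoded_5d.shape[1]), dtype=np.int64)
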
+    def _needs_patch_dimension(self) -> bool:
+        """
+        Determine if this model needs the patch dimension by checking the ONNX input shape.
+
+        Idefics3-based models (like ColModernVBERT) need 5D tensors (batch_size, patch_count, C, H, W).
+        Earlier models (like ColPali v1.3) need 4D tensors (batch_size, C, H, W).
+
+        Returns:
+            bool: True if the pixel_values input has 5 dimensions, False if it has 4
+        """
+        if not hasattr(self, "model") or self.model is None:
+            return False
+
+        # Get pixel_values input metadata
+        for input_meta in self.model.get_inputs():
+            if input_meta.name == "pixel_values":
+                # input_meta.shape is a list like
+                # ['batch_size', 'sequence_length', 'num_channels', 'height', 'width']
+                # or ['batch_size', 'num_channels', 'height', 'width']
+                return len(input_meta.shape) == 5
+
+        # Default to False for backward compatibility
+        return False
+
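
The same rank check can be reproduced against a raw onnxruntime session; `model.onnx` is a placeholder path, not a file from the PR:

import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
for node in session.get_inputs():
    if node.name == "pixel_values":
        # Dynamic axes appear as strings, e.g. ['batch_size', 'num_channels', 'height', 'width'];
        # only the rank (len) matters for the 4D-vs-5D decision.
        print(node.shape, "-> needs patch dim:", len(node.shape) == 5)
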
     def _embed_images(
         self,
         model_name: str,