Commit 746693d

Refactor onnx_embed_image

1 parent e0acfe7 commit 746693d

2 files changed: +108 -41 lines changed

fastembed/late_interaction_multimodal/onnx_multimodal_model.py

Lines changed: 104 additions & 34 deletions
@@ -178,42 +178,13 @@ def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutpu
         assert self.processor is not None, "Processor is not initialized"
         processed = self.processor(image_files)
 
-        # Handle nested structure (with image splitting)
+        # Dispatch to appropriate handler based on structure.
+        # The ColModernVBERT processor divides the original image into
+        # subimages and processes them separately.
         if isinstance(processed[0], list):
-            # processed = [[img1_patches], [img2_patches], ...]
-            # Need shape: (batch_size, max_patches, C, H, W)
-
-            patch_counts = [len(patches) for patches in processed]
-            max_patches = max(patch_counts)
-
-            # Get dimensions from first patch
-            C, H, W = processed[0][0].shape
-
-            # Create padded array
-            batch_size = len(processed)
-            encoded = np.zeros((batch_size, max_patches, C, H, W), dtype=processed[0][0].dtype)
-
-            # Create attention mask (1 for real patches, 0 for padding)
-            attention_mask = np.zeros((batch_size, max_patches), dtype=np.int64)
-
-            # Fill in patches and attention mask
-            for i, patches in enumerate(processed):
-                for j, patch in enumerate(patches):
-                    encoded[i, j] = patch
-                    attention_mask[i, j] = 1
-
-            # Track actual patch counts for later use
-            metadata = {"patch_counts": patch_counts}
+            encoded, attention_mask, metadata = self._process_nested_patches(processed)
         else:
-            # Flat structure (no splitting) - still need batch dimension
-            # Shape: (batch_size, 1, C, H, W)
-            encoded = np.array(processed)
-            if len(encoded.shape) == 4:  # (batch_size, C, H, W)
-                encoded = encoded[:, np.newaxis, ...]  # Add num_patches=1 dimension
-
-            # All patches are real (no padding)
-            attention_mask = np.ones((len(images), encoded.shape[1]), dtype=np.int64)
-            metadata = {"patch_counts": [encoded.shape[1]] * len(images)}
+            encoded, attention_mask, metadata = self._process_flat_images(processed, len(images))
 
         onnx_input = {"pixel_values": encoded, "attention_mask": attention_mask}
         onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs)
@@ -225,6 +196,105 @@ def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutpu
             metadata=metadata,
         )
 
+    def _process_nested_patches(
+        self, processed: list[list[NumpyArray]]
+    ) -> tuple[NumpyArray, NumpyArray, dict[str, Any]]:
+        """
+        Process nested image patches (from ImageSplitter).
+
+        Args:
+            processed: List of patch lists, one per image [[img1_patches], [img2_patches], ...]
+
+        Returns:
+            tuple: (encoded array, attention_mask, metadata)
+                - encoded: (batch_size, max_patches, C, H, W)
+                - attention_mask: (batch_size, max_patches) with 1 for real patches, 0 for padding
+                - metadata: Dict with 'patch_counts' key
+        """
+        patch_counts = [len(patches) for patches in processed]
+        max_patches = max(patch_counts)
+
+        # Get dimensions from first patch
+        C, H, W = processed[0][0].shape
+        batch_size = len(processed)
+
+        # Create padded array
+        encoded = np.zeros((batch_size, max_patches, C, H, W), dtype=processed[0][0].dtype)
+
+        # Create attention mask (1 for real patches, 0 for padding)
+        attention_mask = np.zeros((batch_size, max_patches), dtype=np.int64)
+
+        # Fill in patches and attention mask
+        for i, patches in enumerate(processed):
+            for j, patch in enumerate(patches):
+                encoded[i, j] = patch
+                attention_mask[i, j] = 1
+
+        metadata = {"patch_counts": patch_counts}
+        return encoded, attention_mask, metadata
+
+    def _process_flat_images(
+        self, processed: list[NumpyArray], num_images: int
+    ) -> tuple[NumpyArray, NumpyArray, dict[str, Any]]:
+        """
+        Process flat image arrays (from standard processors like SiglipImageProcessor).
+
+        For models expecting 5D input (Idefics3-based), adds patch dimension.
+        For models expecting 4D input, keeps original shape.
+
+        Args:
+            processed: List of image arrays
+            num_images: Number of images being processed
+
+        Returns:
+            tuple: (encoded array, attention_mask, metadata)
+                - encoded: (batch_size, C, H, W) for 4D models OR (batch_size, 1, C, H, W) for 5D models
+                - attention_mask: (batch_size, 1) with all ones
+                - metadata: Dict with 'patch_counts' key
+        """
+        encoded = np.array(processed)
+
+        # Check if model needs patch dimension based on ONNX signature
+        if len(encoded.shape) == 4 and self._needs_patch_dimension():
+            # Add patch dimension for Idefics3-based models: (batch, 1, C, H, W)
+            encoded = encoded[:, np.newaxis, ...]
+
+        # Determine attention mask shape based on final tensor shape
+        if len(encoded.shape) == 5:
+            # 5D tensor: attention_mask shape is (batch, num_patches)
+            attention_mask = np.ones((num_images, encoded.shape[1]), dtype=np.int64)
+            metadata = {"patch_counts": [encoded.shape[1]] * num_images}
+        else:
+            # 4D tensor: attention_mask shape is (batch, 1)
+            attention_mask = np.ones((num_images, 1), dtype=np.int64)
+            metadata = {"patch_counts": [1] * num_images}
+
+        return encoded, attention_mask, metadata
+
+    def _needs_patch_dimension(self) -> bool:
+        """
+        Determine if this model needs the patch dimension by checking ONNX input shape.
+
+        Idefics3-based models (like ColModernVBERT) need 5D tensors (batch_size, patch_count, C, H, W).
+        Earlier models (like ColPali v1.3) need 4D tensors (batch_size, C, H, W).
+
+        Returns:
+            bool: True if pixel_values input has 5 dimensions, False if 4 dimensions
+        """
+        if not hasattr(self, "model") or self.model is None:
+            return False
+
+        # Get pixel_values input metadata
+        for input_meta in self.model.get_inputs():
+            if input_meta.name == "pixel_values":
+                # input_meta.shape is a list like
+                # ['batch_size', 'sequence_length', 'num_channels', 'height', 'width']
+                # or ['batch_size', 'num_channels', 'height', 'width']
+                return len(input_meta.shape) == 5
+
+        # Default to False for backward compatibility
+        return False
+
     def _embed_images(
         self,
         model_name: str,
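
For readers skimming the diff, here is a minimal standalone sketch of the two mechanisms factored out above: the nested-patch padding done by _process_nested_patches and the input-rank check behind _needs_patch_dimension. It is illustrative only; the dummy arrays, the commented-out onnxruntime session, and the "model.onnx" path are assumptions made for the example, not part of this commit.

import numpy as np

# Nested-patch padding, as in _process_nested_patches:
# two images split into a different number of (C, H, W) patches.
processed = [
    [np.random.rand(3, 32, 32).astype(np.float32) for _ in range(4)],  # image 1: 4 patches
    [np.random.rand(3, 32, 32).astype(np.float32) for _ in range(2)],  # image 2: 2 patches
]

patch_counts = [len(patches) for patches in processed]
max_patches = max(patch_counts)
c, h, w = processed[0][0].shape

encoded = np.zeros((len(processed), max_patches, c, h, w), dtype=np.float32)
attention_mask = np.zeros((len(processed), max_patches), dtype=np.int64)
for i, patches in enumerate(processed):
    for j, patch in enumerate(patches):
        encoded[i, j] = patch
        attention_mask[i, j] = 1

print(encoded.shape)    # (2, 4, 3, 32, 32)
print(attention_mask)   # [[1 1 1 1] [1 1 0 0]]

# Input-rank check, as in _needs_patch_dimension: onnxruntime exposes input
# metadata via InferenceSession.get_inputs(); "model.onnx" is a placeholder path.
# import onnxruntime as ort
# session = ort.InferenceSession("model.onnx")
# pixel_values = next(i for i in session.get_inputs() if i.name == "pixel_values")
# needs_patch_dim = len(pixel_values.shape) == 5  # 5-D input -> Idefics3-style model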

tests/test_late_interaction_multimodal.py

Lines changed: 4 additions & 7 deletions
@@ -48,13 +48,10 @@
     ),
     "Qdrant/colmodernvbert": np.array(
         [
-            [0.05, 0.0656, 0.0403, 0.1498, 0.1842, 0.0263, -0.1871],
-            [-0.0566, -0.1403, 0.0065, -0.0285, 0.0903, -0.0149, 0.1069],
-            [-0.1015, -0.0072, 0.0908, -0.0824, -0.0185, -0.0097, -0.0046],
-            [-0.1233, -0.1081, -0.0234, -0.0033, 0.0598, 0.0993, 0.0985],
-            [-0.0705, -0.1312, -0.0649, 0.0151, 0.0746, 0.0765, 0.1482],
-            [0.0053, -0.1384, -0.0584, -0.0272, 0.1301, 0.0508, 0.1796],
-            [0.0092, -0.1438, -0.0306, -0.0369, 0.1172, 0.037, 0.1334],
+            [0.0541, 0.0677, 0.0392, 0.1494, 0.1855, 0.0275, -0.1835, -0.1025, -0.1204, -0.0835],
+            [-0.0515, -0.1328, 0.0298, -0.0574, 0.0829, -0.0836, 0.0888, 0.0138, 0.0741, 0.0293],
+            [-0.1114, -0.0506, 0.0666, -0.1064, -0.0229, -0.0486, -0.007, 0.0932, 0.0054, 0.1113],
+            [0.2317, -0.0518, 0.0248, -0.0075, -0.078, 0.2073, -0.0912, -0.0622, -0.0203, 0.093]
         ]
     ),
 }
