diff --git a/fastembed/image/onnx_embedding.py b/fastembed/image/onnx_embedding.py
index 5647c2ff..47beae22 100644
--- a/fastembed/image/onnx_embedding.py
+++ b/fastembed/image/onnx_embedding.py
@@ -53,6 +53,17 @@
         },
         "model_file": "model.onnx",
     },
+    {
+        "model": "jinaai/jina-clip-v1",
+        "dim": 768,
+        "description": "Image embeddings, Multimodal (text&image), 2024 year",
+        "license": "apache-2.0",
+        "size_in_GB": 0.34,
+        "sources": {
+            "hf": "jinaai/jina-clip-v1",
+        },
+        "model_file": "onnx/vision_model.onnx",
+    },
 ]
diff --git a/fastembed/image/transform/functional.py b/fastembed/image/transform/functional.py
index 70da2a22..afefe4be 100644
--- a/fastembed/image/transform/functional.py
+++ b/fastembed/image/transform/functional.py
@@ -62,8 +62,8 @@ def center_crop(
 def normalize(
     image: np.ndarray,
-    mean=Union[float, np.ndarray],
-    std=Union[float, np.ndarray],
+    mean: Union[float, np.ndarray],
+    std: Union[float, np.ndarray],
 ) -> np.ndarray:
     if not isinstance(image, np.ndarray):
         raise ValueError("image must be a numpy array")
@@ -96,10 +96,10 @@ def normalize(
 def resize(
-    image: Image,
+    image: Image.Image,
     size: Union[int, tuple[int, int]],
-    resample: Image.Resampling = Image.Resampling.BILINEAR,
-) -> Image:
+    resample: Union[int, Image.Resampling] = Image.Resampling.BILINEAR,
+) -> Image.Image:
     if isinstance(size, tuple):
         return image.resize(size, resample)
@@ -122,3 +122,29 @@ def pil2ndarray(image: Union[Image.Image, np.ndarray]):
     if isinstance(image, Image.Image):
         return np.asarray(image).transpose((2, 0, 1))
     return image
+
+
+def pad2square(
+    image: Image.Image,
+    size: int,
+    fill_color: Union[str, int, tuple[int, ...]] = 0,
+) -> Image.Image:
+    height, width = image.height, image.width
+
+    left, right = 0, width
+    top, bottom = 0, height
+
+    crop_required = False
+    if width > size:
+        left = (width - size) // 2
+        right = left + size
+        crop_required = True
+
+    if height > size:
+        top = (height - size) // 2
+        bottom = top + size
+        crop_required = True
+
+    new_image = Image.new(mode="RGB", size=(size, size), color=fill_color)
+    new_image.paste(image.crop((left, top, right, bottom)) if crop_required else image)
+    return new_image
diff --git a/fastembed/image/transform/operators.py b/fastembed/image/transform/operators.py
index 2b943dbb..bac65e08 100644
--- a/fastembed/image/transform/operators.py
+++ b/fastembed/image/transform/operators.py
@@ -1,4 +1,4 @@
-from typing import Any, Union
+from typing import Any, Union, Optional

 import numpy as np
 from PIL import Image
@@ -10,6 +10,7 @@
     pil2ndarray,
     rescale,
     resize,
+    pad2square,
 )
@@ -66,6 +67,21 @@ def __call__(self, images: list[Union[Image.Image, np.ndarray]]) -> list[np.ndar
         return [pil2ndarray(image) for image in images]


+class PadtoSquare(Transform):
+    def __init__(
+        self,
+        size: int,
+        fill_color: Optional[Union[str, int, tuple[int, ...]]] = None,
+    ):
+        self.size = size
+        self.fill_color = fill_color
+
+    def __call__(self, images: list[Image.Image]) -> list[Image.Image]:
+        return [
+            pad2square(image=image, size=self.size, fill_color=self.fill_color) for image in images
+        ]
+
+
 class Compose:
     def __init__(self, transforms: list[Transform]):
         self.transforms = transforms
@@ -85,14 +101,20 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":
         Valid keys:
             - do_resize
+            - resize_mode
             - size
+            - fill_color
             - do_center_crop
             - crop_size
             - do_rescale
             - rescale_factor
             - do_normalize
             - image_mean
+            - mean
             - image_std
+            - std
+            - resample
+            - interpolation

         Valid size keys (nested):
             - {"height", "width"}
             - {"shortest_edge"}
@@ -103,6 +125,7 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":
         transforms = []
         cls._get_convert_to_rgb(transforms, config)
         cls._get_resize(transforms, config)
+        cls._get_pad2square(transforms, config)
         cls._get_center_crop(transforms, config)
         cls._get_pil2ndarray(transforms, config)
         cls._get_rescale(transforms, config)
@@ -113,8 +136,8 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":
     def _get_convert_to_rgb(transforms: list[Transform], config: dict[str, Any]):
         transforms.append(ConvertToRGB())

-    @staticmethod
-    def _get_resize(transforms: list[Transform], config: dict[str, Any]):
+    @classmethod
+    def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]):
         mode = config.get("image_processor_type", "CLIPImageProcessor")
         if mode == "CLIPImageProcessor":
             if config.get("do_resize", False):
@@ -157,6 +180,24 @@ def _get_resize(transforms: list[Transform], config: dict[str, Any]):
                         resample=config.get("resample", Image.Resampling.BICUBIC),
                     )
                 )
+        elif mode == "JinaCLIPImageProcessor":
+            interpolation = config.get("interpolation")
+            if isinstance(interpolation, str):
+                resample = cls._interpolation_resolver(interpolation)
+            else:
+                resample = interpolation or Image.Resampling.BICUBIC
+
+            if "size" in config:
+                resize_mode = config.get("resize_mode", "shortest")
+                if resize_mode == "shortest":
+                    transforms.append(
+                        Resize(
+                            size=config["size"],
+                            resample=resample,
+                        )
+                    )
+        else:
+            raise ValueError(f"Preprocessor {mode} is not supported")

     @staticmethod
     def _get_center_crop(transforms: list[Transform], config: dict[str, Any]):
@@ -173,6 +214,8 @@ def _get_center_crop(transforms: list[Transform], config: dict[str, Any]):
             transforms.append(CenterCrop(size=crop_size))
         elif mode == "ConvNextFeatureExtractor":
             pass
+        elif mode == "JinaCLIPImageProcessor":
+            pass
         else:
             raise ValueError(f"Preprocessor {mode} is not supported")
@@ -190,3 +233,36 @@ def _get_rescale(transforms: list[Transform], config: dict[str, Any]):
     def _get_normalize(transforms: list[Transform], config: dict[str, Any]):
         if config.get("do_normalize", False):
             transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"]))
+        elif "mean" in config and "std" in config:
+            transforms.append(Normalize(mean=config["mean"], std=config["std"]))
+
+    @staticmethod
+    def _get_pad2square(transforms: list[Transform], config: dict[str, Any]):
+        mode = config.get("image_processor_type", "CLIPImageProcessor")
+        if mode == "CLIPImageProcessor":
+            pass
+        elif mode == "ConvNextFeatureExtractor":
+            pass
+        elif mode == "JinaCLIPImageProcessor":
+            transforms.append(
+                PadtoSquare(
+                    size=config["size"],
+                    fill_color=config.get("fill_color", 0),
+                )
+            )
+
+    @staticmethod
+    def _interpolation_resolver(resample: Optional[str] = None) -> Image.Resampling:
+        interpolation_map = {
+            "nearest": Image.Resampling.NEAREST,
+            "lanczos": Image.Resampling.LANCZOS,
+            "bilinear": Image.Resampling.BILINEAR,
+            "bicubic": Image.Resampling.BICUBIC,
+            "box": Image.Resampling.BOX,
+            "hamming": Image.Resampling.HAMMING,
+        }
+
+        if resample and (method := interpolation_map.get(resample.lower())):
+            return method
+
+        raise ValueError(f"Unknown interpolation method: {resample}")
diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py
index b16d39d9..e4d657c7 100644
--- a/fastembed/text/onnx_embedding.py
+++ b/fastembed/text/onnx_embedding.py
@@ -164,6 +164,17 @@
         },
         "model_file": "onnx/model.onnx",
     },
+    {
+        "model": "jinaai/jina-clip-v1",
+        "dim": 768,
+        "description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year",
+        "license": "apache-2.0",
+        "size_in_GB": 0.55,
+        "sources": {
+            "hf": "jinaai/jina-clip-v1",
+        },
+        "model_file": "onnx/text_model.onnx",
+    },
 ]
@@ -285,7 +296,13 @@ def _preprocess_onnx_input(
     def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
         embeddings = output.model_output
-        return normalize(embeddings[:, 0]).astype(np.float32)
+        if embeddings.ndim == 3:  # (batch_size, seq_len, embedding_dim)
+            processed_embeddings = embeddings[:, 0]
+        elif embeddings.ndim == 2:  # (batch_size, embedding_dim)
+            processed_embeddings = embeddings
+        else:
+            raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
+        return normalize(processed_embeddings).astype(np.float32)

     def load_onnx_model(self) -> None:
         self._load_onnx_model(
diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py
index 81e8d6f0..960d68f7 100644
--- a/fastembed/text/text_embedding.py
+++ b/fastembed/text/text_embedding.py
@@ -78,7 +78,7 @@ def __init__(
             return

         raise ValueError(
-            f"Model {model_name} is not supported in TextEmbedding."
+            f"Model {model_name} is not supported in TextEmbedding. "
             "Please check the supported models using `TextEmbedding.list_supported_models()`"
         )
diff --git a/tests/test_image_onnx_embeddings.py b/tests/test_image_onnx_embeddings.py
index 78194caf..a5fb8e36 100644
--- a/tests/test_image_onnx_embeddings.py
+++ b/tests/test_image_onnx_embeddings.py
@@ -21,6 +21,9 @@
     "Qdrant/Unicom-ViT-B-32": np.array(
         [0.0418, 0.0550, 0.0003, 0.0253, -0.0185, 0.0016, -0.0368, -0.0402, -0.0891, -0.0186]
     ),
+    "jinaai/jina-clip-v1": np.array(
+        [-0.029, 0.0216, 0.0396, 0.0283, -0.0023, 0.0151, 0.011, -0.0235, 0.0251, -0.0343]
+    ),
 }
diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py
index a40794be..ac2c41a3 100644
--- a/tests/test_text_onnx_embeddings.py
+++ b/tests/test_text_onnx_embeddings.py
@@ -65,6 +65,7 @@
     "snowflake/snowflake-arctic-embed-l": np.array([0.0189, -0.0673, 0.0183, 0.0124, 0.0146]),
     "Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
     "thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]),
+    "jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]),
 }
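
Reviewer note (not part of the patch): a rough end-to-end sketch of what this change enables. The model name, ONNX file names, and the 768-dim output come from the registration entries above; the `TextEmbedding`/`ImageEmbedding` entry points are fastembed's existing public API, and the image path below is a placeholder.

```python
import numpy as np

from fastembed import ImageEmbedding, TextEmbedding

# Text tower: served by onnx/text_model.onnx registered above, 768-dim output.
text_model = TextEmbedding(model_name="jinaai/jina-clip-v1")
text_vecs = np.stack(list(text_model.embed(["a photo of a cat"])))
assert text_vecs.shape[-1] == 768

# Vision tower: served by onnx/vision_model.onnx; "image.jpeg" is a placeholder path.
image_model = ImageEmbedding(model_name="jinaai/jina-clip-v1")
image_vecs = np.stack(list(image_model.embed(["image.jpeg"])))
assert image_vecs.shape[-1] == 768

# The text post-processing in this patch L2-normalizes its output; assuming the
# image embeddings are normalized as well, the dot product is the cosine similarity.
similarity = float(text_vecs[0] @ image_vecs[0])
print(similarity)
```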