diff --git a/README.md b/README.md
index 61c1dd32..88ef031b 100644
--- a/README.md
+++ b/README.md
@@ -63,6 +63,23 @@
 embeddings = list(model.embed(documents))
 ```
 
+Dense text embeddings can also be extended with models that are not in the list of supported models.
+
+```python
+from fastembed import TextEmbedding
+from fastembed.common.model_description import PoolingType, ModelSource
+
+TextEmbedding.add_custom_model(
+    model="intfloat/multilingual-e5-small",
+    pooling=PoolingType.MEAN,
+    normalization=True,
+    sources=ModelSource(hf="intfloat/multilingual-e5-small"),  # can be used with a `url` to load files from private storage
+    dim=384,
+    model_file="onnx/model.onnx",  # can be used to load an already supported model with another optimization or quantization, e.g. onnx/model_O4.onnx
+)
+model = TextEmbedding(model_name="intfloat/multilingual-e5-small")
+embeddings = list(model.embed(documents))
+```
 
 ### 🔱 Sparse text embeddings
 
@@ -137,6 +154,27 @@
 embeddings = list(model.embed(images))
 # ]
 ```
 
+### Late interaction multimodal models (ColPali)
+
+```python
+from fastembed import LateInteractionMultimodalEmbedding
+
+doc_images = [
+    "./path/to/qdrant_pdf_doc_1_screenshot.jpg",
+    "./path/to/colpali_pdf_doc_2_screenshot.jpg",
+]
+
+query = "What is Qdrant?"
+
+model = LateInteractionMultimodalEmbedding(model_name="Qdrant/colpali-v1.3-fp16")
+doc_images_embeddings = list(model.embed_image(doc_images))
+# shape (2, 1030, 128)
+# [array([[-0.03353882, -0.02090454, ..., -0.15576172, -0.07678223]], dtype=float32)]
+query_embedding = model.embed_text(query)
+# shape (1, 20, 128)
+# [array([[-0.00218201, 0.14758301, ..., -0.02207947, 0.16833496]], dtype=float32)]
+```
+
 ### 🔄 Rerankers
 
 ```python
 from fastembed.rerank.cross_encoder import TextCrossEncoder
diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py
index 731c902b..0193bed9 100644
--- a/fastembed/late_interaction_multimodal/colpali.py
+++ b/fastembed/late_interaction_multimodal/colpali.py
@@ -197,12 +197,11 @@ def _preprocess_onnx_image_input(
         Returns:
             Dict[str, NumpyArray]: ONNX input with text placeholders.
""" - onnx_input["input_ids"] = np.array( - [self.EMPTY_TEXT_PLACEHOLDER for _ in onnx_input["input_ids"]] + [self.EMPTY_TEXT_PLACEHOLDER for _ in onnx_input["pixel_values"]] ) onnx_input["attention_mask"] = np.array( - [self.EVEN_ATTENTION_MASK for _ in onnx_input["input_ids"]] + [self.EVEN_ATTENTION_MASK for _ in onnx_input["pixel_values"]] ) return onnx_input diff --git a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py index e34b8c0e..089ba1b7 100644 --- a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py +++ b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py @@ -73,8 +73,8 @@ def _load_onnx_model( cuda=cuda, device_id=device_id, ) - assert self.tokenizer is not None self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir) + assert self.tokenizer is not None self.processor = load_preprocessor(model_dir=model_dir) def load_onnx_model(self) -> None: @@ -159,10 +159,6 @@ def _embed_documents( for batch in pool.ordered_map(iter_batch(documents, batch_size), **params): yield from self._post_process_onnx_text_output(batch) # type: ignore - def _build_onnx_image_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]: - input_name = self.model.get_inputs()[0].name # type: ignore[union-attr] - return {input_name: encoded} - def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext: with contextlib.ExitStack(): image_files = [ @@ -171,7 +167,7 @@ def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutpu ] assert self.processor is not None, "Processor is not initialized" encoded = np.array(self.processor(image_files)) - onnx_input = self._build_onnx_image_input(encoded) + onnx_input = {"pixel_values": encoded} onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs) model_output = self.model.run(None, onnx_input) # type: ignore[union-attr] embeddings = model_output[0].reshape(len(images), -1)