Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,23 @@ embeddings = list(model.embed(documents))

```

Dense text embeddings can also be extended with custom models that are not in the list of supported models.

```python
from fastembed import TextEmbedding
from fastembed.common.model_description import PoolingType, ModelSource

TextEmbedding.add_custom_model(
model="intfloat/multilingual-e5-small",
pooling=PoolingType.MEAN,
normalization=True,
sources=ModelSource(hf="intfloat/multilingual-e5-small"), # can be used with a `url` to load files from private storage
dim=384,
model_file="onnx/model.onnx", # can be used to load an already supported model with another optimization or quantization, e.g. onnx/model_O4.onnx
)
model = TextEmbedding(model_name="intfloat/multilingual-e5-small")
embeddings = list(model.embed(documents))
```


### 🔱 Sparse text embeddings
Expand Down Expand Up @@ -137,6 +154,27 @@ embeddings = list(model.embed(images))
# ]
```

### Late interaction multimodal models (ColPali)

```python
from fastembed import LateInteractionMultimodalEmbedding

doc_images = [
"./path/to/qdrant_pdf_doc_1_screenshot.jpg",
"./path/to/colpali_pdf_doc_2_screenshot.jpg",
]

query = "What is Qdrant?"

model = LateInteractionMultimodalEmbedding(model_name="Qdrant/colpali-v1.3-fp16")
doc_images_embeddings = list(model.embed_image(doc_images))
# shape (2, 1030, 128)
# [array([[-0.03353882, -0.02090454, ..., -0.15576172, -0.07678223]], dtype=float32)]
query_embedding = model.embed_text(query)
# shape (1, 20, 128)
# [array([[-0.00218201, 0.14758301, ..., -0.02207947, 0.16833496]], dtype=float32)]
```

### 🔄 Rerankers
```python
from fastembed.rerank.cross_encoder import TextCrossEncoder
Expand Down
5 changes: 2 additions & 3 deletions fastembed/late_interaction_multimodal/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,11 @@ def _preprocess_onnx_image_input(
Returns:
Dict[str, NumpyArray]: ONNX input with text placeholders.
"""

onnx_input["input_ids"] = np.array(
[self.EMPTY_TEXT_PLACEHOLDER for _ in onnx_input["input_ids"]]
[self.EMPTY_TEXT_PLACEHOLDER for _ in onnx_input["pixel_values"]]
)
onnx_input["attention_mask"] = np.array(
[self.EVEN_ATTENTION_MASK for _ in onnx_input["input_ids"]]
[self.EVEN_ATTENTION_MASK for _ in onnx_input["pixel_values"]]
)
return onnx_input

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def _load_onnx_model(
cuda=cuda,
device_id=device_id,
)
assert self.tokenizer is not None
self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir)
assert self.tokenizer is not None
self.processor = load_preprocessor(model_dir=model_dir)

def load_onnx_model(self) -> None:
Expand Down Expand Up @@ -159,10 +159,6 @@ def _embed_documents(
for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
yield from self._post_process_onnx_text_output(batch) # type: ignore

# Build the ONNX input dict for an image batch: look up the name of the model's
# first declared input and map it to the already-encoded image array.
def _build_onnx_image_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]:
# Only the first entry of `self.model.get_inputs()` is consulted.
input_name = self.model.get_inputs()[0].name # type: ignore[union-attr]
return {input_name: encoded}

def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
with contextlib.ExitStack():
image_files = [
Expand All @@ -171,7 +167,7 @@ def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutpu
]
assert self.processor is not None, "Processor is not initialized"
encoded = np.array(self.processor(image_files))
onnx_input = self._build_onnx_image_input(encoded)
onnx_input = {"pixel_values": encoded}
onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs)
model_output = self.model.run(None, onnx_input) # type: ignore[union-attr]
embeddings = model_output[0].reshape(len(images), -1)
Expand Down