fastembed/common/types.py (1 addition, 0 deletions)

@@ -16,6 +16,7 @@

 OnnxProvider: TypeAlias = Union[str, tuple[str, dict[Any, Any]]]
 NumpyArray = Union[
+    NDArray[np.float64],
     NDArray[np.float32],
     NDArray[np.float16],
     NDArray[np.int8],
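Together with the removed casts in the files below, this widened alias lets post-processing return embeddings in the model's native dtype instead of forcing float32. A minimal sketch of how the alias now admits float64 arrays (the passthrough helper is hypothetical, for illustration only):

    import numpy as np
    from numpy.typing import NDArray
    from typing import Union

    # Mirrors the widened alias in fastembed/common/types.py
    NumpyArray = Union[
        NDArray[np.float64],
        NDArray[np.float32],
        NDArray[np.float16],
        NDArray[np.int8],
    ]

    def passthrough(output: NumpyArray) -> NumpyArray:
        # No .astype(np.float32): the array's native dtype is preserved
        return output

    emb = passthrough(np.random.random((2, 4)))  # float64 now satisfies the alias
    assert emb.dtype == np.float64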
fastembed/late_interaction/colbert.py (3 additions, 3 deletions)

@@ -46,7 +46,7 @@ def _post_process_onnx_output(
         self, output: OnnxOutputContext, is_doc: bool = True
     ) -> Iterable[NumpyArray]:
         if not is_doc:
-            return output.model_output.astype(np.float32)
+            return output.model_output

         if output.input_ids is None or output.attention_mask is None:
             raise ValueError(
@@ -58,11 +58,11 @@ def _post_process_onnx_output(
                 if token_id in self.skip_list or token_id == self.pad_token_id:
                     output.attention_mask[i, j] = 0

-        output.model_output *= np.expand_dims(output.attention_mask, 2).astype(np.float32)
+        output.model_output *= np.expand_dims(output.attention_mask, 2)
         norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
         norm_clamped = np.maximum(norm, 1e-12)
         output.model_output /= norm_clamped
-        return output.model_output.astype(np.float32)
+        return output.model_output

     def _preprocess_onnx_input(
         self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any
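A note on why the mask cast could go: NumPy's in-place operators keep the left operand's dtype, so multiplying the float model output by an integer attention mask never changes the output's dtype. A self-contained sketch of the masking-and-normalization step above, with illustrative shapes:

    import numpy as np

    model_output = np.random.random((1, 5, 8))  # (batch, tokens, dim); float64 by default
    attention_mask = np.array([[1, 1, 1, 0, 0]], dtype=np.int64)

    model_output *= np.expand_dims(attention_mask, 2)  # in-place: dtype stays float64
    norm = np.linalg.norm(model_output, ord=2, axis=2, keepdims=True)
    model_output /= np.maximum(norm, 1e-12)  # clamp to avoid division by zero
    assert model_output.dtype == np.float64  # no forced float32 cast anywhere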
fastembed/late_interaction_multimodal/colpali.py (2 additions, 2 deletions)

@@ -142,7 +142,7 @@ def _post_process_onnx_image_output(
         assert self.model_description.dim is not None, "Model dim is not defined"
         return output.model_output.reshape(
             output.model_output.shape[0], -1, self.model_description.dim
-        ).astype(np.float32)
+        )

     def _post_process_onnx_text_output(
         self,
@@ -157,7 +157,7 @@ def _post_process_onnx_text_output(
         Returns:
             Iterable[NumpyArray]: Post-processed output as NumPy arrays.
         """
-        return output.model_output.astype(np.float32)
+        return output.model_output

     def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
         texts_query: list[str] = []
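The image branch only reshapes the flat ONNX output to (batch, patches, dim), and reshape never changes dtype, so dropping the cast here is purely a dtype-policy change. A quick check under assumed shapes:

    import numpy as np

    dim = 128                              # illustrative model dim
    raw = np.random.random((2, 32 * dim))  # flat per-image output, float64 here
    per_patch = raw.reshape(raw.shape[0], -1, dim)
    assert per_patch.shape == (2, 32, dim)
    assert per_patch.dtype == raw.dtype    # reshape preserves dtype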
fastembed/text/onnx_embedding.py (1 addition, 2 deletions)

@@ -1,6 +1,5 @@
 from typing import Any, Iterable, Optional, Sequence, Type, Union

-import numpy as np
 from fastembed.common.types import NumpyArray, OnnxProvider
 from fastembed.common.onnx_model import OnnxOutputContext
 from fastembed.common.utils import define_cache_dir, normalize
@@ -309,7 +308,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
             processed_embeddings = embeddings
         else:
             raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
-        return normalize(processed_embeddings).astype(np.float32)
+        return normalize(processed_embeddings)

     def load_onnx_model(self) -> None:
         self._load_onnx_model(
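With the cast removed, the output dtype now follows normalize, which in turn follows its input. A minimal row-wise L2 normalization in the spirit of fastembed.common.utils.normalize (an assumption about its behavior, not its actual implementation):

    import numpy as np

    def l2_normalize(embeddings: np.ndarray) -> np.ndarray:
        # Row-wise L2 norm; the result inherits the input dtype
        norm = np.linalg.norm(embeddings, axis=1, keepdims=True)
        return embeddings / np.maximum(norm, 1e-12)

    emb = np.random.random((3, 8))  # float64 input
    assert l2_normalize(emb).dtype == np.float64  # no trailing .astype(np.float32)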
fastembed/text/pooled_embedding.py (1 addition, 1 deletion)

@@ -115,7 +115,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:

         embeddings = output.model_output
         attn_mask = output.attention_mask
-        return self.mean_pooling(embeddings, attn_mask).astype(np.float32)
+        return self.mean_pooling(embeddings, attn_mask)


 class PooledEmbeddingWorker(OnnxTextEmbeddingWorker):
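For context, masked mean pooling sums token embeddings where the attention mask is 1 and divides by the token count; the division keeps the embedding dtype, so no cast is needed. A sketch in the spirit of self.mean_pooling (illustrative, not fastembed's exact code):

    import numpy as np

    def mean_pooling(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
        # (batch, tokens, dim) masked by (batch, tokens, 1), then averaged over tokens
        mask = np.expand_dims(attention_mask, -1).astype(token_embeddings.dtype)
        summed = (token_embeddings * mask).sum(axis=1)
        counts = np.maximum(mask.sum(axis=1), 1e-9)  # avoid division by zero
        return summed / counts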
fastembed/text/pooled_normalized_embedding.py (1 addition, 2 deletions)

@@ -1,6 +1,5 @@
 from typing import Any, Iterable, Type

-import numpy as np

 from fastembed.common.types import NumpyArray
 from fastembed.common.onnx_model import OnnxOutputContext
@@ -144,7 +143,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:

         embeddings = output.model_output
         attn_mask = output.attention_mask
-        return normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
+        return normalize(self.mean_pooling(embeddings, attn_mask))


 class PooledNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker):
pyproject.toml (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "fastembed"
-version = "0.5.1"
+version = "0.6.0"
 description = "Fast, light, accurate library built for retrieval embedding generation"
 authors = ["Qdrant Team <[email protected]>", "NirantK <[email protected]>"]
 license = "Apache License"
tests/test_custom_models.py (5 additions, 9 deletions)

@@ -71,8 +71,8 @@ def test_mock_add_custom_models():
     source = ModelSource(hf="artificial")

     num_tokens = 10
-    dummy_pooled_embedding = np.random.random((1, dim)).astype(np.float32)
-    dummy_token_embedding = np.random.random((1, num_tokens, dim)).astype(np.float32)
+    dummy_pooled_embedding = np.random.random((1, dim))
+    dummy_token_embedding = np.random.random((1, num_tokens, dim))
     dummy_attention_mask = np.ones((1, num_tokens)).astype(np.int64)

     dummy_token_output = OnnxOutputContext(
@@ -91,15 +91,11 @@ def test_mock_add_custom_models():
     expected_output = {
         f"{PoolingType.MEAN.lower()}-normalized": normalize(
             mean_pooling(dummy_token_embedding, dummy_attention_mask)
-        ).astype(np.float32),
+        ),
         f"{PoolingType.MEAN.lower()}": mean_pooling(dummy_token_embedding, dummy_attention_mask),
-        f"{PoolingType.CLS.lower()}-normalized": normalize(dummy_token_embedding[:, 0]).astype(
-            np.float32
-        ),
+        f"{PoolingType.CLS.lower()}-normalized": normalize(dummy_token_embedding[:, 0]),
         f"{PoolingType.CLS.lower()}": dummy_token_embedding[:, 0],
-        f"{PoolingType.DISABLED.lower()}-normalized": normalize(dummy_pooled_embedding).astype(
-            np.float32
-        ),
+        f"{PoolingType.DISABLED.lower()}-normalized": normalize(dummy_pooled_embedding),
         f"{PoolingType.DISABLED.lower()}": dummy_pooled_embedding,
     }
Expand Down