diff --git a/fastembed/common/types.py b/fastembed/common/types.py
index 4b6a570d..a1adccbb 100644
--- a/fastembed/common/types.py
+++ b/fastembed/common/types.py
@@ -16,6 +16,7 @@
 OnnxProvider: TypeAlias = Union[str, tuple[str, dict[Any, Any]]]
 
 NumpyArray = Union[
+    NDArray[np.float64],
     NDArray[np.float32],
     NDArray[np.float16],
     NDArray[np.int8],
diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py
index 841d3a73..eb9545b3 100644
--- a/fastembed/late_interaction/colbert.py
+++ b/fastembed/late_interaction/colbert.py
@@ -46,7 +46,7 @@ def _post_process_onnx_output(
         self, output: OnnxOutputContext, is_doc: bool = True
     ) -> Iterable[NumpyArray]:
         if not is_doc:
-            return output.model_output.astype(np.float32)
+            return output.model_output
 
         if output.input_ids is None or output.attention_mask is None:
             raise ValueError(
@@ -58,11 +58,11 @@
                 if token_id in self.skip_list or token_id == self.pad_token_id:
                     output.attention_mask[i, j] = 0
 
-        output.model_output *= np.expand_dims(output.attention_mask, 2).astype(np.float32)
+        output.model_output *= np.expand_dims(output.attention_mask, 2)
         norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
         norm_clamped = np.maximum(norm, 1e-12)
         output.model_output /= norm_clamped
-        return output.model_output.astype(np.float32)
+        return output.model_output
 
     def _preprocess_onnx_input(
         self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any
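With the forced `.astype(np.float32)` casts removed in the two files above, ColBERT post-processing keeps whatever dtype the ONNX output and NumPy operations produce, and `NumpyArray` now admits `float64` as well. A minimal plain-NumPy sketch of the resulting behavior (illustrative only, not fastembed API):

```python
import numpy as np

# np.random.random yields float64, NumPy's default floating dtype, so values
# that previously passed through .astype(np.float32) can now surface as float64.
token_embeddings = np.random.random((2, 5, 8))            # float64
attention_mask = np.ones((2, 5), dtype=np.int64)

scaled = token_embeddings * np.expand_dims(attention_mask, 2)   # dtype preserved
norm = np.maximum(np.linalg.norm(scaled, ord=2, axis=2, keepdims=True), 1e-12)
normalized = scaled / norm

assert normalized.dtype == np.float64   # no implicit float32 cast anymore

# Callers that still need float32 (e.g. for a vector store) cast explicitly:
normalized_f32 = normalized.astype(np.float32)
```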
""" - return output.model_output.astype(np.float32) + return output.model_output def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: texts_query: list[str] = [] diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index d272c37a..158bc09d 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -1,6 +1,5 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union -import numpy as np from fastembed.common.types import NumpyArray, OnnxProvider from fastembed.common.onnx_model import OnnxOutputContext from fastembed.common.utils import define_cache_dir, normalize @@ -309,7 +308,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Numpy processed_embeddings = embeddings else: raise ValueError(f"Unsupported embedding shape: {embeddings.shape}") - return normalize(processed_embeddings).astype(np.float32) + return normalize(processed_embeddings) def load_onnx_model(self) -> None: self._load_onnx_model( diff --git a/fastembed/text/pooled_embedding.py b/fastembed/text/pooled_embedding.py index 1dc8c9f5..3c356577 100644 --- a/fastembed/text/pooled_embedding.py +++ b/fastembed/text/pooled_embedding.py @@ -115,7 +115,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Numpy embeddings = output.model_output attn_mask = output.attention_mask - return self.mean_pooling(embeddings, attn_mask).astype(np.float32) + return self.mean_pooling(embeddings, attn_mask) class PooledEmbeddingWorker(OnnxTextEmbeddingWorker): diff --git a/fastembed/text/pooled_normalized_embedding.py b/fastembed/text/pooled_normalized_embedding.py index ed825eca..95ec2b78 100644 --- a/fastembed/text/pooled_normalized_embedding.py +++ b/fastembed/text/pooled_normalized_embedding.py @@ -1,6 +1,5 @@ from typing import Any, Iterable, Type -import numpy as np from fastembed.common.types import NumpyArray from fastembed.common.onnx_model import OnnxOutputContext @@ -144,7 +143,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Numpy embeddings = output.model_output attn_mask = output.attention_mask - return normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32) + return normalize(self.mean_pooling(embeddings, attn_mask)) class PooledNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker): diff --git a/pyproject.toml b/pyproject.toml index f4af5889..4effb948 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "fastembed" -version = "0.5.1" +version = "0.6.0" description = "Fast, light, accurate library built for retrieval embedding generation" authors = ["Qdrant Team ", "NirantK "] license = "Apache License" diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index f893c57d..b64d1a79 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -71,8 +71,8 @@ def test_mock_add_custom_models(): source = ModelSource(hf="artificial") num_tokens = 10 - dummy_pooled_embedding = np.random.random((1, dim)).astype(np.float32) - dummy_token_embedding = np.random.random((1, num_tokens, dim)).astype(np.float32) + dummy_pooled_embedding = np.random.random((1, dim)) + dummy_token_embedding = np.random.random((1, num_tokens, dim)) dummy_attention_mask = np.ones((1, num_tokens)).astype(np.int64) dummy_token_output = OnnxOutputContext( @@ -91,15 +91,11 @@ def test_mock_add_custom_models(): expected_output = { f"{PoolingType.MEAN.lower()}-normalized": normalize( mean_pooling(dummy_token_embedding, 
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index f893c57d..b64d1a79 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -71,8 +71,8 @@ def test_mock_add_custom_models():
     source = ModelSource(hf="artificial")
     num_tokens = 10
 
-    dummy_pooled_embedding = np.random.random((1, dim)).astype(np.float32)
-    dummy_token_embedding = np.random.random((1, num_tokens, dim)).astype(np.float32)
+    dummy_pooled_embedding = np.random.random((1, dim))
+    dummy_token_embedding = np.random.random((1, num_tokens, dim))
     dummy_attention_mask = np.ones((1, num_tokens)).astype(np.int64)
 
     dummy_token_output = OnnxOutputContext(
@@ -91,15 +91,11 @@ def test_mock_add_custom_models():
     expected_output = {
         f"{PoolingType.MEAN.lower()}-normalized": normalize(
             mean_pooling(dummy_token_embedding, dummy_attention_mask)
-        ).astype(np.float32),
-        f"{PoolingType.MEAN.lower()}": mean_pooling(dummy_token_embedding, dummy_attention_mask),
-        f"{PoolingType.CLS.lower()}-normalized": normalize(dummy_token_embedding[:, 0]).astype(
-            np.float32
         ),
+        f"{PoolingType.MEAN.lower()}": mean_pooling(dummy_token_embedding, dummy_attention_mask),
+        f"{PoolingType.CLS.lower()}-normalized": normalize(dummy_token_embedding[:, 0]),
         f"{PoolingType.CLS.lower()}": dummy_token_embedding[:, 0],
-        f"{PoolingType.DISABLED.lower()}-normalized": normalize(dummy_pooled_embedding).astype(
-            np.float32
-        ),
+        f"{PoolingType.DISABLED.lower()}-normalized": normalize(dummy_pooled_embedding),
         f"{PoolingType.DISABLED.lower()}": dummy_pooled_embedding,
     }
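The updated expectations above hold because `mean_pooling` and `normalize` simply propagate the float64 dummy inputs once the explicit casts are gone. A generic masked mean-pooling sketch (not necessarily fastembed's exact `mean_pooling` implementation) showing how the expected values follow from the token embeddings and attention mask:

```python
import numpy as np

def masked_mean_pooling(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    # token_embeddings: (batch, tokens, dim); attention_mask: (batch, tokens)
    mask = np.expand_dims(attention_mask, axis=-1).astype(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)  # avoid division by zero for empty masks
    return summed / counts                          # dtype follows the inputs (float64 here)
```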