diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py
index bed3ca1e..d272c37a 100644
--- a/fastembed/text/onnx_embedding.py
+++ b/fastembed/text/onnx_embedding.py
@@ -93,18 +93,6 @@
         ),
         model_file="model_optimized.onnx",
     ),
-    DenseModelDescription(
-        model="thenlper/gte-large",
-        dim=1024,
-        description=(
-            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
-            "Prefixes for queries/documents: not necessary, 2023 year."
-        ),
-        license="mit",
-        size_in_GB=1.20,
-        sources=ModelSource(hf="qdrant/gte-large-onnx"),
-        model_file="model.onnx",
-    ),
     DenseModelDescription(
         model="mixedbread-ai/mxbai-embed-large-v1",
         dim=1024,
@@ -314,6 +302,7 @@ def _preprocess_onnx_input(
 
     def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
         embeddings = output.model_output
+
         if embeddings.ndim == 3:  # (batch_size, seq_len, embedding_dim)
             processed_embeddings = embeddings[:, 0]
         elif embeddings.ndim == 2:  # (batch_size, embedding_dim)
diff --git a/fastembed/text/pooled_normalized_embedding.py b/fastembed/text/pooled_normalized_embedding.py
index f0b58b64..ed825eca 100644
--- a/fastembed/text/pooled_normalized_embedding.py
+++ b/fastembed/text/pooled_normalized_embedding.py
@@ -109,6 +109,18 @@
         sources=ModelSource(hf="thenlper/gte-base"),
         model_file="onnx/model.onnx",
     ),
+    DenseModelDescription(
+        model="thenlper/gte-large",
+        dim=1024,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: not necessary, 2023 year."
+        ),
+        license="mit",
+        size_in_GB=1.20,
+        sources=ModelSource(hf="qdrant/gte-large-onnx"),
+        model_file="model.onnx",
+    ),
 ]
 
 
diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py
index c94a5883..b9c049cd 100644
--- a/fastembed/text/text_embedding.py
+++ b/fastembed/text/text_embedding.py
@@ -90,29 +90,24 @@ def __init__(
         super().__init__(model_name, cache_dir, threads, **kwargs)
         if model_name == "nomic-ai/nomic-embed-text-v1.5-Q":
             warnings.warn(
-                "The model 'nomic-ai/nomic-embed-text-v1.5-Q' has been updated on HuggingFace. "
-                "Please review the latest documentation and release notes to ensure compatibility with your workflow. ",
-                UserWarning,
-                stacklevel=2,
-            )
-        if model_name == "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2":
-            warnings.warn(
-                "The model 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2' has been updated to "
-                "include a mean pooling layer. Please ensure your usage aligns with the new functionality. "
-                "Support for the previous version without mean pooling will be removed as of version 0.5.2.",
+                "The model 'nomic-ai/nomic-embed-text-v1.5-Q' has been updated on HuggingFace. Please review "
+                "the latest documentation on HF and release notes to ensure compatibility with your workflow. ",
                 UserWarning,
                 stacklevel=2,
             )
         if model_name in {
-            "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
+            "thenlper/gte-large",
             "intfloat/multilingual-e5-large",
+            "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
         }:
             warnings.warn(
-                f"{model_name} has been updated as of fastembed 0.5.2, outputs are now average pooled.",
+                f"The model {model_name} now uses mean pooling instead of CLS embedding. "
+                f"In order to preserve the previous behaviour, consider either pinning fastembed version to 0.5.1 or "
+                "using `add_custom_model` functionality.",
                 UserWarning,
                 stacklevel=2,
             )
-
         for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
             supported_models = EMBEDDING_MODEL_TYPE._list_supported_models()
             if any(model_name.lower() == model.model.lower() for model in supported_models):
diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py
index e4bfa0bf..cf39d7d1 100644
--- a/tests/test_text_onnx_embeddings.py
+++ b/tests/test_text_onnx_embeddings.py
@@ -52,7 +52,7 @@
         [0.0802303, 0.3700881, -4.3053818, 0.4431803, -0.271572]
     ),
     "thenlper/gte-large": np.array(
-        [-0.01920587, 0.00113156, -0.00708992, -0.00632304, -0.04025577]
+        [-0.00986551, -0.00018734, 0.00605892, -0.03289612, -0.0387564],
     ),
     "mixedbread-ai/mxbai-embed-large-v1": np.array(
         [0.02295546, 0.03196154, 0.016512, -0.04031524, -0.0219634]