From d2611d64dd7855875ccd25545fe1952d714f90c8 Mon Sep 17 00:00:00 2001 From: Daniel Boros Date: Tue, 21 Oct 2025 17:16:13 +0200 Subject: [PATCH 01/15] feat: add public tokenize method for TextEmbedding class --- fastembed/text/text_embedding.py | 16 ++++++++++++++++ fastembed/text/text_embedding_base.py | 4 ++++ 2 files changed, 20 insertions(+) diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index 117f5af7..48de1844 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -1,5 +1,6 @@ import warnings from typing import Any, Iterable, Optional, Sequence, Type, Union +from tokenizers import Encoding from dataclasses import asdict from fastembed.common.types import NumpyArray, OnnxProvider @@ -162,6 +163,21 @@ def get_embedding_size(cls, model_name: str) -> int: ) return embedding_size + def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + """ + Tokenize input texts using the model's tokenizer. + + Args: + texts: String or list of strings to tokenize + **kwargs: Additional arguments passed to the tokenizer + + Returns: + List of tokenizer Encodings + """ + if isinstance(texts, str): + texts = [texts] + return self.model.tokenize(texts, **kwargs) + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/text/text_embedding_base.py b/fastembed/text/text_embedding_base.py index 75df9ac5..7f52979b 100644 --- a/fastembed/text/text_embedding_base.py +++ b/fastembed/text/text_embedding_base.py @@ -1,4 +1,5 @@ from typing import Iterable, Optional, Union, Any +from tokenizers import Encoding from fastembed.common.model_description import DenseModelDescription from fastembed.common.types import NumpyArray @@ -19,6 +20,9 @@ def __init__( self._local_files_only = kwargs.pop("local_files_only", False) self._embedding_size: Optional[int] = None + def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + raise NotImplementedError() + def embed( self, documents: Union[str, Iterable[str]], From d7abe5fea8279b367a5277c37e73f7cb8d292de5 Mon Sep 17 00:00:00 2001 From: Daniel Boros Date: Wed, 22 Oct 2025 11:43:08 +0200 Subject: [PATCH 02/15] feat: add tokenize method and tests --- .../late_interaction_embedding_base.py | 5 +++++ .../late_interaction_text_embedding.py | 17 +++++++++++++++ .../late_interaction_multimodal_embedding.py | 17 +++++++++++++++ ...e_interaction_multimodal_embedding_base.py | 4 ++++ fastembed/sparse/sparse_embedding_base.py | 4 ++++ fastembed/sparse/sparse_text_embedding.py | 17 +++++++++++++++ fastembed/sparse/splade_pp.py | 17 +++++++++++++++ tests/test_late_interaction_embeddings.py | 21 +++++++++++++++++++ tests/test_late_interaction_multimodal.py | 20 ++++++++++++++++++ tests/test_sparse_embeddings.py | 21 +++++++++++++++++++ tests/test_text_onnx_embeddings.py | 21 +++++++++++++++++++ 11 files changed, 164 insertions(+) diff --git a/fastembed/late_interaction/late_interaction_embedding_base.py b/fastembed/late_interaction/late_interaction_embedding_base.py index f677ba98..f66144da 100644 --- a/fastembed/late_interaction/late_interaction_embedding_base.py +++ b/fastembed/late_interaction/late_interaction_embedding_base.py @@ -1,5 +1,7 @@ from typing import Iterable, Optional, Union, Any +from tokenizers import Encoding + from fastembed.common.model_description import DenseModelDescription from fastembed.common.types import NumpyArray from fastembed.common.model_management import ModelManagement @@ -19,6 +21,9 @@ def __init__( self._local_files_only = kwargs.pop("local_files_only", False) self._embedding_size: Optional[int] = None + def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + raise NotImplementedError() + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/late_interaction/late_interaction_text_embedding.py b/fastembed/late_interaction/late_interaction_text_embedding.py index 22833618..ed541ff6 100644 --- a/fastembed/late_interaction/late_interaction_text_embedding.py +++ b/fastembed/late_interaction/late_interaction_text_embedding.py @@ -1,6 +1,8 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union from dataclasses import asdict +from tokenizers import Encoding + from fastembed.common.model_description import DenseModelDescription from fastembed.common.types import NumpyArray from fastembed.common import OnnxProvider @@ -114,6 +116,21 @@ def get_embedding_size(cls, model_name: str) -> int: ) return embedding_size + def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + """ + Tokenize input texts using the model's tokenizer. + + Args: + texts: String or list of strings to tokenize + **kwargs: Additional arguments passed to the tokenizer + + Returns: + List of tokenizer Encodings + """ + if isinstance(texts, str): + texts = [texts] + return self.model.tokenize(texts, **kwargs) + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py index 39c1763e..f2c8190a 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py @@ -1,6 +1,8 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union from dataclasses import asdict +from tokenizers import Encoding + from fastembed.common import OnnxProvider, ImageInput from fastembed.common.types import NumpyArray from fastembed.late_interaction_multimodal.colpali import ColPali @@ -117,6 +119,21 @@ def get_embedding_size(cls, model_name: str) -> int: ) return embedding_size + def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + """ + Tokenize input texts using the model's tokenizer. + + Args: + texts: String or list of strings to tokenize + **kwargs: Additional arguments passed to the tokenizer + + Returns: + List of tokenizer Encodings + """ + if isinstance(texts, str): + texts = [texts] + return self.model.tokenize(texts, **kwargs) + def embed_text( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py index 12e3553c..3767ad12 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py @@ -1,5 +1,6 @@ from typing import Iterable, Optional, Union, Any +from tokenizers import Encoding from fastembed.common import ImageInput from fastembed.common.model_description import DenseModelDescription @@ -21,6 +22,9 @@ def __init__( self._local_files_only = kwargs.pop("local_files_only", False) self._embedding_size: Optional[int] = None + def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + raise NotImplementedError() + def embed_text( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/sparse/sparse_embedding_base.py b/fastembed/sparse/sparse_embedding_base.py index b153c814..ce6ace70 100644 --- a/fastembed/sparse/sparse_embedding_base.py +++ b/fastembed/sparse/sparse_embedding_base.py @@ -3,6 +3,7 @@ import numpy as np from numpy.typing import NDArray +from tokenizers import Encoding from fastembed.common.model_description import SparseModelDescription from fastembed.common.types import NumpyArray @@ -44,6 +45,9 @@ def __init__( self.threads = threads self._local_files_only = kwargs.pop("local_files_only", False) + def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + raise NotImplementedError() + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/sparse/sparse_text_embedding.py b/fastembed/sparse/sparse_text_embedding.py index 3cb14c3e..255eb1a9 100644 --- a/fastembed/sparse/sparse_text_embedding.py +++ b/fastembed/sparse/sparse_text_embedding.py @@ -1,6 +1,8 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union from dataclasses import asdict +from tokenizers import Encoding + from fastembed.common import OnnxProvider from fastembed.sparse.bm25 import Bm25 from fastembed.sparse.bm42 import Bm42 @@ -91,6 +93,21 @@ def __init__( "Please check the supported models using `SparseTextEmbedding.list_supported_models()`" ) + def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + """ + Tokenize input texts using the model's tokenizer. + + Args: + texts: String or list of strings to tokenize + **kwargs: Additional arguments passed to the tokenizer + + Returns: + List of tokenizer Encodings + """ + if isinstance(texts, str): + texts = [texts] + return self.model.tokenize(texts, **kwargs) + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/sparse/splade_pp.py b/fastembed/sparse/splade_pp.py index d2c4af38..662d55d1 100644 --- a/fastembed/sparse/splade_pp.py +++ b/fastembed/sparse/splade_pp.py @@ -1,6 +1,8 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union import numpy as np +from tokenizers import Encoding + from fastembed.common import OnnxProvider from fastembed.common.onnx_model import OnnxOutputContext from fastembed.common.utils import define_cache_dir @@ -135,6 +137,21 @@ def load_onnx_model(self) -> None: device_id=self.device_id, ) + def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + """ + Tokenize input texts using the model's tokenizer. + + Args: + texts: String or list of strings to tokenize + **kwargs: Additional arguments passed to the tokenizer + + Returns: + List of tokenizer Encodings + """ + if isinstance(texts, str): + texts = [texts] + return self.tokenizer.encode_batch(texts) + def embed( self, documents: Union[str, Iterable[str]], diff --git a/tests/test_late_interaction_embeddings.py b/tests/test_late_interaction_embeddings.py index f89882f4..ce806365 100644 --- a/tests/test_late_interaction_embeddings.py +++ b/tests/test_late_interaction_embeddings.py @@ -308,3 +308,24 @@ def test_embedding_size(): assert model.embedding_size == 96 if is_ci: delete_model_cache(model.model._model_dir) + + +@pytest.mark.parametrize("model_name", ["colbert-ir/colbertv2.0"]) +def test_tokenize(model_name: str) -> None: + is_ci = os.getenv("CI") + model = LateInteractionTextEmbedding(model_name=model_name) + + texts = ["hello world", "flag embedding"] + encodings = model.tokenize(texts, is_doc=True) + assert len(encodings) == 2 + for encoding in encodings: + assert encoding.ids is not None + assert len(encoding.ids) > 0 + + encodings = model.tokenize(["hello world"], is_doc=False) + assert len(encodings) == 1 + assert encodings[0].ids is not None + assert len(encodings[0].ids) > 0 + + if is_ci: + delete_model_cache(model.model._model_dir) diff --git a/tests/test_late_interaction_multimodal.py b/tests/test_late_interaction_multimodal.py index 80135f3b..2a1bf631 100644 --- a/tests/test_late_interaction_multimodal.py +++ b/tests/test_late_interaction_multimodal.py @@ -101,3 +101,23 @@ def test_embedding_size(): model_name = "Qdrant/ColPali-v1.3-fp16" model = LateInteractionMultimodalEmbedding(model_name=model_name, lazy_load=True) assert model.embedding_size == 128 + + +@pytest.mark.parametrize("model_name", ["Qdrant/colpali-v1.3-fp16"]) +def test_tokenize(model_name: str) -> None: + if os.getenv("CI"): + pytest.skip("Colpali is too large to test in CI") + + model = LateInteractionMultimodalEmbedding(model_name=model_name) + + encodings = model.tokenize(["hello world"]) + assert len(encodings) == 1 + assert encodings[0].ids is not None + assert len(encodings[0].ids) > 0 + + texts = ["hello world", "flag embedding"] + encodings = model.tokenize(texts) + assert len(encodings) == 2 + for encoding in encodings: + assert encoding.ids is not None + assert len(encoding.ids) > 0 diff --git a/tests/test_sparse_embeddings.py b/tests/test_sparse_embeddings.py index 90514967..58899b08 100644 --- a/tests/test_sparse_embeddings.py +++ b/tests/test_sparse_embeddings.py @@ -276,3 +276,24 @@ def test_lazy_load(model_name: str) -> None: if is_ci: delete_model_cache(model.model._model_dir) + + +@pytest.mark.parametrize("model_name", ["prithivida/Splade_PP_en_v1"]) +def test_tokenize(model_name: str) -> None: + is_ci = os.getenv("CI") + model = SparseTextEmbedding(model_name=model_name) + + encodings = model.tokenize("hello world") + assert len(encodings) == 1 + assert encodings[0].ids is not None + assert len(encodings[0].ids) > 0 + + texts = ["hello world", "flag embedding"] + encodings = model.tokenize(texts) + assert len(encodings) == 2 + for encoding in encodings: + assert encoding.ids is not None + assert len(encoding.ids) > 0 + + if is_ci: + delete_model_cache(model.model._model_dir) diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index 46ce6554..e1722042 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -193,3 +193,24 @@ def test_embedding_size() -> None: if is_ci: delete_model_cache(model.model._model_dir) + + +@pytest.mark.parametrize("model_name", ["BAAI/bge-small-en-v1.5"]) +def test_tokenize(model_name: str) -> None: + is_ci = os.getenv("CI") + model = TextEmbedding(model_name=model_name) + + encodings = model.tokenize("hello world") + assert len(encodings) == 1 + assert encodings[0].ids is not None + assert len(encodings[0].ids) > 0 + + texts = ["hello world", "flag embedding"] + encodings = model.tokenize(texts) + assert len(encodings) == 2 + for encoding in encodings: + assert encoding.ids is not None + assert len(encoding.ids) > 0 + + if is_ci: + delete_model_cache(model.model._model_dir) From 4f0ad849485fc7d373c3aa318aa2639b6c043667 Mon Sep 17 00:00:00 2001 From: Daniel Boros Date: Wed, 22 Oct 2025 11:51:26 +0200 Subject: [PATCH 03/15] fix: test errors --- fastembed/late_interaction/colbert.py | 2 +- fastembed/late_interaction_multimodal/colpali.py | 5 +++-- fastembed/sparse/splade_pp.py | 5 ++++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 3aa105ae..71be00af 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -80,7 +80,7 @@ def _preprocess_onnx_input( ) return onnx_input - def tokenize(self, documents: list[str], is_doc: bool = True, **kwargs: Any) -> list[Encoding]: + def tokenize(self, documents: list[str], is_doc: bool = True, **kwargs: Any) -> list[Encoding]: # type: ignore[override] return ( self._tokenize_documents(documents=documents) if is_doc diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index 9fab9535..39bfd389 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -160,14 +160,15 @@ def _post_process_onnx_text_output( """ return output.model_output - def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: # type: ignore[override] texts_query: list[str] = [] for query in documents: query = self.BOS_TOKEN + self.QUERY_PREFIX + query + self.PAD_TOKEN * 10 query += "\n" texts_query.append(query) - encoded = self.tokenizer.encode_batch(texts_query) # type: ignore[union-attr] + assert self.tokenizer is not None + encoded = self.tokenizer.encode_batch(texts_query) return encoded def _preprocess_onnx_text_input( diff --git a/fastembed/sparse/splade_pp.py b/fastembed/sparse/splade_pp.py index 662d55d1..ce90ddd8 100644 --- a/fastembed/sparse/splade_pp.py +++ b/fastembed/sparse/splade_pp.py @@ -137,7 +137,7 @@ def load_onnx_model(self) -> None: device_id=self.device_id, ) - def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: # type: ignore[override] """ Tokenize input texts using the model's tokenizer. @@ -150,6 +150,9 @@ def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Enco """ if isinstance(texts, str): texts = [texts] + if not isinstance(texts, list): + texts = list(texts) + assert self.tokenizer is not None return self.tokenizer.encode_batch(texts) def embed( From cc431f3890732b78af913a8b82be0b36117f45f0 Mon Sep 17 00:00:00 2001 From: Daniel Boros Date: Wed, 22 Oct 2025 12:11:09 +0200 Subject: [PATCH 04/15] fix: type override --- fastembed/late_interaction/token_embeddings.py | 2 +- fastembed/sparse/bm42.py | 2 +- fastembed/sparse/minicoil.py | 2 +- fastembed/text/onnx_embedding.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fastembed/late_interaction/token_embeddings.py b/fastembed/late_interaction/token_embeddings.py index ec4844ba..8a00a146 100644 --- a/fastembed/late_interaction/token_embeddings.py +++ b/fastembed/late_interaction/token_embeddings.py @@ -25,7 +25,7 @@ ] -class TokenEmbeddingsModel(OnnxTextEmbedding, LateInteractionTextEmbeddingBase): +class TokenEmbeddingsModel(OnnxTextEmbedding, LateInteractionTextEmbeddingBase): # type: ignore[misc] @classmethod def _list_supported_models(cls) -> list[DenseModelDescription]: """Lists the supported models. diff --git a/fastembed/sparse/bm42.py b/fastembed/sparse/bm42.py index 3e51404f..bd2dc2c8 100644 --- a/fastembed/sparse/bm42.py +++ b/fastembed/sparse/bm42.py @@ -44,7 +44,7 @@ def get_language_by_model_name(model_name: str) -> str: return MODEL_TO_LANGUAGE[model_name.lower()] -class Bm42(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]): +class Bm42(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]): # type: ignore[misc] """ Bm42 is an extension of BM25, which tries to better evaluate importance of tokens in the documents, by extracting attention weights from the transformer model. diff --git a/fastembed/sparse/minicoil.py b/fastembed/sparse/minicoil.py index efaa9abb..e5cf724e 100644 --- a/fastembed/sparse/minicoil.py +++ b/fastembed/sparse/minicoil.py @@ -58,7 +58,7 @@ def get_language_by_model_name(model_name: str) -> str: return MODEL_TO_LANGUAGE[model_name.lower()] -class MiniCOIL(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]): +class MiniCOIL(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]): # type: ignore[misc] """ MiniCOIL is a sparse embedding model, that resolves semantic meaning of the words, while keeping exact keyword match behavior. diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index 4cc892f5..4f3a8140 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -183,7 +183,7 @@ ] -class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[NumpyArray]): +class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[NumpyArray]): # type: ignore[misc] """Implementation of the Flag Embedding model.""" @classmethod From c4f2297933a6fd5bfd04214bc4eef82d078974a4 Mon Sep 17 00:00:00 2001 From: Daniel Boros Date: Wed, 22 Oct 2025 12:16:03 +0200 Subject: [PATCH 05/15] fix: tokenization test --- fastembed/text/onnx_embedding.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index 4f3a8140..368dfc6c 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -1,5 +1,7 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union +from tokenizers import Encoding + from fastembed.common.types import NumpyArray, OnnxProvider from fastembed.common.onnx_model import OnnxOutputContext from fastembed.common.utils import define_cache_dir, normalize @@ -319,6 +321,20 @@ def _post_process_onnx_output( raise ValueError(f"Unsupported embedding shape: {embeddings.shape}") return normalize(processed_embeddings) + def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: # type: ignore[override] + """Tokenize the input texts. + + Args: + texts: A single string or an iterable of strings to tokenize. + **kwargs: Additional keyword arguments. + + Returns: + list[Encoding]: List of tokenized encodings. + """ + if isinstance(texts, str): + texts = [texts] + return OnnxTextModel.tokenize(self, list(texts), **kwargs) + def load_onnx_model(self) -> None: self._load_onnx_model( model_dir=self._model_dir, From 97b862cb7abd9336811c7d8eb0ff8af8764a40f3 Mon Sep 17 00:00:00 2001 From: Daniel Boros Date: Wed, 22 Oct 2025 13:18:36 +0200 Subject: [PATCH 06/15] chore: list[str] instead of union --- fastembed/late_interaction/colbert.py | 6 +++--- .../late_interaction_embedding_base.py | 2 +- .../late_interaction_text_embedding.py | 6 ++---- fastembed/late_interaction_multimodal/colpali.py | 4 ++-- .../late_interaction_multimodal_embedding.py | 6 ++---- .../late_interaction_multimodal_embedding_base.py | 2 +- .../onnx_multimodal_model.py | 4 ++-- fastembed/sparse/sparse_embedding_base.py | 2 +- fastembed/sparse/sparse_text_embedding.py | 6 ++---- fastembed/sparse/splade_pp.py | 13 +++++-------- fastembed/text/onnx_embedding.py | 8 +++----- fastembed/text/onnx_text_model.py | 4 ++-- fastembed/text/text_embedding.py | 6 ++---- fastembed/text/text_embedding_base.py | 2 +- tests/test_sparse_embeddings.py | 2 +- tests/test_text_onnx_embeddings.py | 2 +- 16 files changed, 31 insertions(+), 44 deletions(-) diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 71be00af..327d1429 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -80,11 +80,11 @@ def _preprocess_onnx_input( ) return onnx_input - def tokenize(self, documents: list[str], is_doc: bool = True, **kwargs: Any) -> list[Encoding]: # type: ignore[override] + def tokenize(self, texts: list[str], is_doc: bool = True, **kwargs: Any) -> list[Encoding]: # type: ignore[override] return ( - self._tokenize_documents(documents=documents) + self._tokenize_documents(documents=texts) if is_doc - else self._tokenize_query(query=next(iter(documents))) + else self._tokenize_query(query=next(iter(texts))) ) def _tokenize_query(self, query: str) -> list[Encoding]: diff --git a/fastembed/late_interaction/late_interaction_embedding_base.py b/fastembed/late_interaction/late_interaction_embedding_base.py index f66144da..e6f00d8c 100644 --- a/fastembed/late_interaction/late_interaction_embedding_base.py +++ b/fastembed/late_interaction/late_interaction_embedding_base.py @@ -21,7 +21,7 @@ def __init__( self._local_files_only = kwargs.pop("local_files_only", False) self._embedding_size: Optional[int] = None - def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: raise NotImplementedError() def embed( diff --git a/fastembed/late_interaction/late_interaction_text_embedding.py b/fastembed/late_interaction/late_interaction_text_embedding.py index ed541ff6..e52e8154 100644 --- a/fastembed/late_interaction/late_interaction_text_embedding.py +++ b/fastembed/late_interaction/late_interaction_text_embedding.py @@ -116,19 +116,17 @@ def get_embedding_size(cls, model_name: str) -> int: ) return embedding_size - def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: """ Tokenize input texts using the model's tokenizer. Args: - texts: String or list of strings to tokenize + texts: List of strings to tokenize **kwargs: Additional arguments passed to the tokenizer Returns: List of tokenizer Encodings """ - if isinstance(texts, str): - texts = [texts] return self.model.tokenize(texts, **kwargs) def embed( diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index 39bfd389..7242aaad 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -160,9 +160,9 @@ def _post_process_onnx_text_output( """ return output.model_output - def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: # type: ignore[override] + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: # type: ignore[override] texts_query: list[str] = [] - for query in documents: + for query in texts: query = self.BOS_TOKEN + self.QUERY_PREFIX + query + self.PAD_TOKEN * 10 query += "\n" diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py index f2c8190a..cc534909 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py @@ -119,19 +119,17 @@ def get_embedding_size(cls, model_name: str) -> int: ) return embedding_size - def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: """ Tokenize input texts using the model's tokenizer. Args: - texts: String or list of strings to tokenize + texts: List of strings to tokenize **kwargs: Additional arguments passed to the tokenizer Returns: List of tokenizer Encodings """ - if isinstance(texts, str): - texts = [texts] return self.model.tokenize(texts, **kwargs) def embed_text( diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py index 3767ad12..be39e25c 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py @@ -22,7 +22,7 @@ def __init__( self._local_files_only = kwargs.pop("local_files_only", False) self._embedding_size: Optional[int] = None - def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: raise NotImplementedError() def embed_text( diff --git a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py index 83706a2b..d29c5016 100644 --- a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py +++ b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py @@ -80,8 +80,8 @@ def _load_onnx_model( def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method") - def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: - return self.tokenizer.encode_batch(documents) # type: ignore[union-attr] + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: + return self.tokenizer.encode_batch(texts) def onnx_embed_text( self, diff --git a/fastembed/sparse/sparse_embedding_base.py b/fastembed/sparse/sparse_embedding_base.py index ce6ace70..1cdcb2c2 100644 --- a/fastembed/sparse/sparse_embedding_base.py +++ b/fastembed/sparse/sparse_embedding_base.py @@ -45,7 +45,7 @@ def __init__( self.threads = threads self._local_files_only = kwargs.pop("local_files_only", False) - def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: raise NotImplementedError() def embed( diff --git a/fastembed/sparse/sparse_text_embedding.py b/fastembed/sparse/sparse_text_embedding.py index 255eb1a9..16a1ca06 100644 --- a/fastembed/sparse/sparse_text_embedding.py +++ b/fastembed/sparse/sparse_text_embedding.py @@ -93,19 +93,17 @@ def __init__( "Please check the supported models using `SparseTextEmbedding.list_supported_models()`" ) - def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: """ Tokenize input texts using the model's tokenizer. Args: - texts: String or list of strings to tokenize + texts: List of strings to tokenize **kwargs: Additional arguments passed to the tokenizer Returns: List of tokenizer Encodings """ - if isinstance(texts, str): - texts = [texts] return self.model.tokenize(texts, **kwargs) def embed( diff --git a/fastembed/sparse/splade_pp.py b/fastembed/sparse/splade_pp.py index ce90ddd8..240e96a6 100644 --- a/fastembed/sparse/splade_pp.py +++ b/fastembed/sparse/splade_pp.py @@ -137,23 +137,20 @@ def load_onnx_model(self) -> None: device_id=self.device_id, ) - def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: # type: ignore[override] + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: # type: ignore[override] """ Tokenize input texts using the model's tokenizer. Args: - texts: String or list of strings to tokenize + texts: List of strings to tokenize **kwargs: Additional arguments passed to the tokenizer Returns: List of tokenizer Encodings """ - if isinstance(texts, str): - texts = [texts] - if not isinstance(texts, list): - texts = list(texts) - assert self.tokenizer is not None - return self.tokenizer.encode_batch(texts) + if self.tokenizer is None: + raise RuntimeError("Tokenizer not initialized") + return self.tokenizer.encode_batch(texts, **kwargs) def embed( self, diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index 368dfc6c..24d142be 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -321,19 +321,17 @@ def _post_process_onnx_output( raise ValueError(f"Unsupported embedding shape: {embeddings.shape}") return normalize(processed_embeddings) - def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: # type: ignore[override] + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: """Tokenize the input texts. Args: - texts: A single string or an iterable of strings to tokenize. + texts: A list of strings to tokenize. **kwargs: Additional keyword arguments. Returns: list[Encoding]: List of tokenized encodings. """ - if isinstance(texts, str): - texts = [texts] - return OnnxTextModel.tokenize(self, list(texts), **kwargs) + return OnnxTextModel.tokenize(self, texts, **kwargs) def load_onnx_model(self) -> None: self._load_onnx_model( diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py index c939b21d..5a6e0a6f 100644 --- a/fastembed/text/onnx_text_model.py +++ b/fastembed/text/onnx_text_model.py @@ -68,8 +68,8 @@ def _load_onnx_model( def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method") - def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: - return self.tokenizer.encode_batch(documents) # type: ignore[union-attr] + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: + return self.tokenizer.encode_batch(texts) # type: ignore[union-attr] def onnx_embed( self, diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index 48de1844..2e56e4af 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -163,19 +163,17 @@ def get_embedding_size(cls, model_name: str) -> int: ) return embedding_size - def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: """ Tokenize input texts using the model's tokenizer. Args: - texts: String or list of strings to tokenize + texts: List of strings to tokenize **kwargs: Additional arguments passed to the tokenizer Returns: List of tokenizer Encodings """ - if isinstance(texts, str): - texts = [texts] return self.model.tokenize(texts, **kwargs) def embed( diff --git a/fastembed/text/text_embedding_base.py b/fastembed/text/text_embedding_base.py index 7f52979b..f9dabf90 100644 --- a/fastembed/text/text_embedding_base.py +++ b/fastembed/text/text_embedding_base.py @@ -20,7 +20,7 @@ def __init__( self._local_files_only = kwargs.pop("local_files_only", False) self._embedding_size: Optional[int] = None - def tokenize(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> list[Encoding]: + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: raise NotImplementedError() def embed( diff --git a/tests/test_sparse_embeddings.py b/tests/test_sparse_embeddings.py index 58899b08..16395c60 100644 --- a/tests/test_sparse_embeddings.py +++ b/tests/test_sparse_embeddings.py @@ -283,7 +283,7 @@ def test_tokenize(model_name: str) -> None: is_ci = os.getenv("CI") model = SparseTextEmbedding(model_name=model_name) - encodings = model.tokenize("hello world") + encodings = model.tokenize(["hello world"]) assert len(encodings) == 1 assert encodings[0].ids is not None assert len(encodings[0].ids) > 0 diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index e1722042..65f01af0 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -200,7 +200,7 @@ def test_tokenize(model_name: str) -> None: is_ci = os.getenv("CI") model = TextEmbedding(model_name=model_name) - encodings = model.tokenize("hello world") + encodings = model.tokenize(["hello world"]) assert len(encodings) == 1 assert encodings[0].ids is not None assert len(encodings[0].ids) > 0 From 4045721c292999cc70a908de618d16c500fcab22 Mon Sep 17 00:00:00 2001 From: Daniel Boros Date: Wed, 22 Oct 2025 13:28:20 +0200 Subject: [PATCH 07/15] fix: tpye issue --- fastembed/late_interaction_multimodal/onnx_multimodal_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py index d29c5016..3a22fc46 100644 --- a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py +++ b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py @@ -81,7 +81,7 @@ def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method") def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: - return self.tokenizer.encode_batch(texts) + return self.tokenizer.encode_batch(texts) # type: ignore[union-attr] def onnx_embed_text( self, From cecaf296ce89ad357c59fc64a55c595ddb82cb00 Mon Sep 17 00:00:00 2001 From: Daniel Boros Date: Wed, 22 Oct 2025 14:19:33 +0200 Subject: [PATCH 08/15] fix: tests --- fastembed/sparse/bm25.py | 21 +++++++++++ fastembed/sparse/bm42.py | 4 +++ fastembed/sparse/minicoil.py | 5 ++- tests/test_sparse_embeddings.py | 62 +++++++++++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 1 deletion(-) diff --git a/fastembed/sparse/bm25.py b/fastembed/sparse/bm25.py index b6ac59fd..4c92e478 100644 --- a/fastembed/sparse/bm25.py +++ b/fastembed/sparse/bm25.py @@ -7,6 +7,7 @@ import mmh3 import numpy as np from py_rust_stemmers import SnowballStemmer +from tokenizers import Encoding from fastembed.common.utils import ( define_cache_dir, iter_batch, @@ -136,6 +137,26 @@ def __init__( self.tokenizer = SimpleTokenizer + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: + """Tokenize texts using SimpleTokenizer. + + Returns a list of simple Encoding-like objects with token strings. + Note: BM25 uses a simple word tokenizer, not a learned tokenizer. + """ + result = [] + for text in texts: + tokens = self.tokenizer.tokenize(text) + + # Create a simple object that mimics Encoding interface + class SimpleEncoding: + def __init__(self, tokens: list[str]): + self.tokens = tokens + self.ids = tokens # For BM25, tokens are the IDs + self.attention_mask = [1] * len(tokens) + + result.append(SimpleEncoding(tokens)) # type: ignore[arg-type] + return result # type: ignore[return-value] + @classmethod def _list_supported_models(cls) -> list[SparseModelDescription]: """Lists the supported models. diff --git a/fastembed/sparse/bm42.py b/fastembed/sparse/bm42.py index bd2dc2c8..e48c789c 100644 --- a/fastembed/sparse/bm42.py +++ b/fastembed/sparse/bm42.py @@ -16,6 +16,7 @@ ) from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker from fastembed.common.model_description import SparseModelDescription, ModelSource +from tokenizers import Encoding supported_bm42_models: list[SparseModelDescription] = [ SparseModelDescription( @@ -138,6 +139,9 @@ def __init__( if not self.lazy_load: self.load_onnx_model() + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: + return OnnxTextModel.tokenize(self, list(texts), **kwargs) + def load_onnx_model(self) -> None: self._load_onnx_model( model_dir=self._model_dir, diff --git a/fastembed/sparse/minicoil.py b/fastembed/sparse/minicoil.py index e5cf724e..2a049a69 100644 --- a/fastembed/sparse/minicoil.py +++ b/fastembed/sparse/minicoil.py @@ -5,7 +5,7 @@ import numpy as np from numpy.typing import NDArray from py_rust_stemmers import SnowballStemmer -from tokenizers import Tokenizer +from tokenizers import Tokenizer, Encoding from fastembed.common.model_description import SparseModelDescription, ModelSource from fastembed.common.onnx_model import OnnxOutputContext @@ -145,6 +145,9 @@ def __init__( if not self.lazy_load: self.load_onnx_model() + def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: + return OnnxTextModel.tokenize(self, list(texts), **kwargs) + def load_onnx_model(self) -> None: self._load_onnx_model( model_dir=self._model_dir, diff --git a/tests/test_sparse_embeddings.py b/tests/test_sparse_embeddings.py index 16395c60..f8eaed00 100644 --- a/tests/test_sparse_embeddings.py +++ b/tests/test_sparse_embeddings.py @@ -5,6 +5,8 @@ import numpy as np from fastembed.sparse.bm25 import Bm25 +from fastembed.sparse.bm42 import Bm42 +from fastembed.sparse.minicoil import MiniCOIL from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding from tests.utils import delete_model_cache, should_test_model @@ -297,3 +299,63 @@ def test_tokenize(model_name: str) -> None: if is_ci: delete_model_cache(model.model._model_dir) + + +def test_tokenize_bm25() -> None: + is_ci = os.getenv("CI") + model = Bm25("Qdrant/bm25", language="english") + + encodings = model.tokenize(["hello world"]) + assert len(encodings) == 1 + assert encodings[0].ids is not None + assert len(encodings[0].ids) > 0 + + texts = ["hello world", "flag embedding"] + encodings = model.tokenize(texts) + assert len(encodings) == 2 + for encoding in encodings: + assert encoding.ids is not None + assert len(encoding.ids) > 0 + + if is_ci: + delete_model_cache(model._model_dir) + + +def test_tokenize_bm42() -> None: + is_ci = os.getenv("CI") + model = Bm42("Qdrant/bm42-all-minilm-l6-v2-attentions") + + encodings = model.tokenize(["hello world"]) + assert len(encodings) == 1 + assert encodings[0].ids is not None + assert len(encodings[0].ids) > 0 + + texts = ["hello world", "flag embedding"] + encodings = model.tokenize(texts) + assert len(encodings) == 2 + for encoding in encodings: + assert encoding.ids is not None + assert len(encoding.ids) > 0 + + if is_ci: + delete_model_cache(model._model_dir) + + +def test_tokenize_minicoil() -> None: + is_ci = os.getenv("CI") + model = MiniCOIL("Qdrant/minicoil-v1") + + encodings = model.tokenize(["hello world"]) + assert len(encodings) == 1 + assert encodings[0].ids is not None + assert len(encodings[0].ids) > 0 + + texts = ["hello world", "flag embedding"] + encodings = model.tokenize(texts) + assert len(encodings) == 2 + for encoding in encodings: + assert encoding.ids is not None + assert len(encoding.ids) > 0 + + if is_ci: + delete_model_cache(model._model_dir) From 98f0fc8473de70494a7d21a4478735eced99b2f2 Mon Sep 17 00:00:00 2001 From: Daniel Boros Date: Wed, 22 Oct 2025 15:25:17 +0200 Subject: [PATCH 09/15] fix: ai recommendations --- .../late_interaction_multimodal/onnx_multimodal_model.py | 4 +++- fastembed/rerank/cross_encoder/onnx_text_model.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py index 3a22fc46..8b7a5ce0 100644 --- a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py +++ b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py @@ -81,7 +81,9 @@ def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method") def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: - return self.tokenizer.encode_batch(texts) # type: ignore[union-attr] + if self.tokenizer is None: + raise RuntimeError("Tokenizer not initialized") + return self.tokenizer.encode_batch(texts, **kwargs) # type: ignore[union-attr] def onnx_embed_text( self, diff --git a/fastembed/rerank/cross_encoder/onnx_text_model.py b/fastembed/rerank/cross_encoder/onnx_text_model.py index 3fc4e81c..cda7ddda 100644 --- a/fastembed/rerank/cross_encoder/onnx_text_model.py +++ b/fastembed/rerank/cross_encoder/onnx_text_model.py @@ -45,8 +45,10 @@ def _load_onnx_model( self.tokenizer, _ = load_tokenizer(model_dir=model_dir) assert self.tokenizer is not None - def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]: - return self.tokenizer.encode_batch(pairs) # type: ignore[union-attr] + def tokenize(self, pairs: list[tuple[str, str]], **kwargs: Any) -> list[Encoding]: + if self.tokenizer is None: + raise RuntimeError("Tokenizer not initialized") + return self.tokenizer.encode_batch(pairs, **kwargs) # type: ignore[union-attr] def _build_onnx_input(self, tokenized_input: list[Encoding]) -> dict[str, NumpyArray]: input_names: set[str] = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] From 39b0d7fbe4bdf0ffa1f0275ae00b6d281ac3011e Mon Sep 17 00:00:00 2001 From: dancixx Date: Thu, 30 Oct 2025 10:24:38 +0100 Subject: [PATCH 10/15] fix: pr reivews --- tests/test_late_interaction_embeddings.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tests/test_late_interaction_embeddings.py b/tests/test_late_interaction_embeddings.py index ce806365..c037bef9 100644 --- a/tests/test_late_interaction_embeddings.py +++ b/tests/test_late_interaction_embeddings.py @@ -310,22 +310,26 @@ def test_embedding_size(): delete_model_cache(model.model._model_dir) -@pytest.mark.parametrize("model_name", ["colbert-ir/colbertv2.0"]) +@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"]) def test_tokenize(model_name: str) -> None: is_ci = os.getenv("CI") model = LateInteractionTextEmbedding(model_name=model_name) texts = ["hello world", "flag embedding"] - encodings = model.tokenize(texts, is_doc=True) - assert len(encodings) == 2 - for encoding in encodings: + enc_doc = model.tokenize(texts, is_doc=True) + assert len(enc_doc) == 2 + for encoding in enc_doc: assert encoding.ids is not None assert len(encoding.ids) > 0 - encodings = model.tokenize(["hello world"], is_doc=False) - assert len(encodings) == 1 - assert encodings[0].ids is not None - assert len(encodings[0].ids) > 0 + enc_query = model.tokenize(["hello world"], is_doc=False) + assert len(enc_query) == 1 + assert enc_query[0].ids is not None + assert len(enc_query[0].ids) > 0 + + doc_ids = list(enc_doc[0].ids) + query_ids = list(enc_query[0].ids) + assert doc_ids != query_ids if is_ci: delete_model_cache(model.model._model_dir) From 7dba1d0c53d53c5c4826d1ef4deb1fd74e8a00a3 Mon Sep 17 00:00:00 2001 From: George Date: Fri, 31 Oct 2025 01:47:38 +0700 Subject: [PATCH 11/15] refactor: split tokenize into _tokenize and tokenize to respect MRO (#566) --- fastembed/late_interaction/colbert.py | 11 ++- .../late_interaction_embedding_base.py | 2 +- .../late_interaction_text_embedding.py | 6 +- .../late_interaction/token_embeddings.py | 2 +- .../late_interaction_multimodal/colpali.py | 7 +- .../late_interaction_multimodal_embedding.py | 6 +- ...e_interaction_multimodal_embedding_base.py | 2 +- .../onnx_multimodal_model.py | 6 +- .../rerank/cross_encoder/onnx_text_model.py | 2 - fastembed/sparse/bm25.py | 22 +++--- fastembed/sparse/bm42.py | 6 +- fastembed/sparse/minicoil.py | 6 +- fastembed/sparse/sparse_embedding_base.py | 2 +- fastembed/sparse/sparse_text_embedding.py | 6 +- fastembed/sparse/splade_pp.py | 8 +- fastembed/text/onnx_embedding.py | 15 +--- fastembed/text/onnx_text_model.py | 6 +- fastembed/text/text_embedding.py | 6 +- fastembed/text/text_embedding_base.py | 2 +- tests/test_late_interaction_embeddings.py | 33 +++++---- tests/test_sparse_embeddings.py | 73 +++---------------- tests/test_text_onnx_embeddings.py | 5 +- 22 files changed, 90 insertions(+), 144 deletions(-) diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 327d1429..fa742882 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -80,11 +80,16 @@ def _preprocess_onnx_input( ) return onnx_input - def tokenize(self, texts: list[str], is_doc: bool = True, **kwargs: Any) -> list[Encoding]: # type: ignore[override] + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + return self._tokenize(documents, **kwargs) + + def _tokenize( + self, documents: list[str], is_doc: bool = True, **kwargs: Any + ) -> list[Encoding]: return ( - self._tokenize_documents(documents=texts) + self._tokenize_documents(documents=documents) if is_doc - else self._tokenize_query(query=next(iter(texts))) + else self._tokenize_query(query=next(iter(documents))) ) def _tokenize_query(self, query: str) -> list[Encoding]: diff --git a/fastembed/late_interaction/late_interaction_embedding_base.py b/fastembed/late_interaction/late_interaction_embedding_base.py index e6f00d8c..e5ae230a 100644 --- a/fastembed/late_interaction/late_interaction_embedding_base.py +++ b/fastembed/late_interaction/late_interaction_embedding_base.py @@ -21,7 +21,7 @@ def __init__( self._local_files_only = kwargs.pop("local_files_only", False) self._embedding_size: Optional[int] = None - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: raise NotImplementedError() def embed( diff --git a/fastembed/late_interaction/late_interaction_text_embedding.py b/fastembed/late_interaction/late_interaction_text_embedding.py index e52e8154..3d3b41d1 100644 --- a/fastembed/late_interaction/late_interaction_text_embedding.py +++ b/fastembed/late_interaction/late_interaction_text_embedding.py @@ -116,18 +116,18 @@ def get_embedding_size(cls, model_name: str) -> int: ) return embedding_size - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: """ Tokenize input texts using the model's tokenizer. Args: - texts: List of strings to tokenize + documents: List of strings to tokenize **kwargs: Additional arguments passed to the tokenizer Returns: List of tokenizer Encodings """ - return self.model.tokenize(texts, **kwargs) + return self.model.tokenize(documents, **kwargs) def embed( self, diff --git a/fastembed/late_interaction/token_embeddings.py b/fastembed/late_interaction/token_embeddings.py index 8a00a146..ec4844ba 100644 --- a/fastembed/late_interaction/token_embeddings.py +++ b/fastembed/late_interaction/token_embeddings.py @@ -25,7 +25,7 @@ ] -class TokenEmbeddingsModel(OnnxTextEmbedding, LateInteractionTextEmbeddingBase): # type: ignore[misc] +class TokenEmbeddingsModel(OnnxTextEmbedding, LateInteractionTextEmbeddingBase): @classmethod def _list_supported_models(cls) -> list[DenseModelDescription]: """Lists the supported models. diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index 7242aaad..7b466b24 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -160,9 +160,12 @@ def _post_process_onnx_text_output( """ return output.model_output - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: # type: ignore[override] + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + return self._tokenize(documents, **kwargs) + + def _tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: texts_query: list[str] = [] - for query in texts: + for query in documents: query = self.BOS_TOKEN + self.QUERY_PREFIX + query + self.PAD_TOKEN * 10 query += "\n" diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py index cc534909..e7df631d 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py @@ -119,18 +119,18 @@ def get_embedding_size(cls, model_name: str) -> int: ) return embedding_size - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: """ Tokenize input texts using the model's tokenizer. Args: - texts: List of strings to tokenize + documents: List of strings to tokenize **kwargs: Additional arguments passed to the tokenizer Returns: List of tokenizer Encodings """ - return self.model.tokenize(texts, **kwargs) + return self.model.tokenize(documents, **kwargs) def embed_text( self, diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py index be39e25c..bd5d04be 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py @@ -22,7 +22,7 @@ def __init__( self._local_files_only = kwargs.pop("local_files_only", False) self._embedding_size: Optional[int] = None - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: raise NotImplementedError() def embed_text( diff --git a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py index 8b7a5ce0..c2f22b9f 100644 --- a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py +++ b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py @@ -80,17 +80,17 @@ def _load_onnx_model( def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method") - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: + def _tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: if self.tokenizer is None: raise RuntimeError("Tokenizer not initialized") - return self.tokenizer.encode_batch(texts, **kwargs) # type: ignore[union-attr] + return self.tokenizer.encode_batch(documents, **kwargs) # type: ignore[union-attr] def onnx_embed_text( self, documents: list[str], **kwargs: Any, ) -> OnnxOutputContext: - encoded = self.tokenize(documents, **kwargs) + encoded = self._tokenize(documents, **kwargs) input_ids = np.array([e.ids for e in encoded]) attention_mask = np.array([e.attention_mask for e in encoded]) # type: ignore[union-attr] input_names = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] diff --git a/fastembed/rerank/cross_encoder/onnx_text_model.py b/fastembed/rerank/cross_encoder/onnx_text_model.py index cda7ddda..48ee5792 100644 --- a/fastembed/rerank/cross_encoder/onnx_text_model.py +++ b/fastembed/rerank/cross_encoder/onnx_text_model.py @@ -46,8 +46,6 @@ def _load_onnx_model( assert self.tokenizer is not None def tokenize(self, pairs: list[tuple[str, str]], **kwargs: Any) -> list[Encoding]: - if self.tokenizer is None: - raise RuntimeError("Tokenizer not initialized") return self.tokenizer.encode_batch(pairs, **kwargs) # type: ignore[union-attr] def _build_onnx_input(self, tokenized_input: list[Encoding]) -> dict[str, NumpyArray]: diff --git a/fastembed/sparse/bm25.py b/fastembed/sparse/bm25.py index 4c92e478..4ec4ec99 100644 --- a/fastembed/sparse/bm25.py +++ b/fastembed/sparse/bm25.py @@ -137,25 +137,25 @@ def __init__( self.tokenizer = SimpleTokenizer - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: """Tokenize texts using SimpleTokenizer. Returns a list of simple Encoding-like objects with token strings. Note: BM25 uses a simple word tokenizer, not a learned tokenizer. """ result = [] - for text in texts: - tokens = self.tokenizer.tokenize(text) - # Create a simple object that mimics Encoding interface - class SimpleEncoding: - def __init__(self, tokens: list[str]): - self.tokens = tokens - self.ids = tokens # For BM25, tokens are the IDs - self.attention_mask = [1] * len(tokens) + class SimpleEncoding: + def __init__(self, tokens: list[str]): + self.tokens = tokens + self.ids = tokens # For BM25, tokens are the IDs + self.attention_mask = [1] * len(tokens) + + for document in documents: + tokens = self.tokenizer.tokenize(document) + result.append(SimpleEncoding(tokens)) - result.append(SimpleEncoding(tokens)) # type: ignore[arg-type] - return result # type: ignore[return-value] + return result @classmethod def _list_supported_models(cls) -> list[SparseModelDescription]: diff --git a/fastembed/sparse/bm42.py b/fastembed/sparse/bm42.py index e48c789c..937747a8 100644 --- a/fastembed/sparse/bm42.py +++ b/fastembed/sparse/bm42.py @@ -45,7 +45,7 @@ def get_language_by_model_name(model_name: str) -> str: return MODEL_TO_LANGUAGE[model_name.lower()] -class Bm42(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]): # type: ignore[misc] +class Bm42(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]): """ Bm42 is an extension of BM25, which tries to better evaluate importance of tokens in the documents, by extracting attention weights from the transformer model. @@ -139,8 +139,8 @@ def __init__( if not self.lazy_load: self.load_onnx_model() - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: - return OnnxTextModel.tokenize(self, list(texts), **kwargs) + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + return self._tokenize(documents, **kwargs) def load_onnx_model(self) -> None: self._load_onnx_model( diff --git a/fastembed/sparse/minicoil.py b/fastembed/sparse/minicoil.py index 2a049a69..7490f588 100644 --- a/fastembed/sparse/minicoil.py +++ b/fastembed/sparse/minicoil.py @@ -58,7 +58,7 @@ def get_language_by_model_name(model_name: str) -> str: return MODEL_TO_LANGUAGE[model_name.lower()] -class MiniCOIL(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]): # type: ignore[misc] +class MiniCOIL(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]): """ MiniCOIL is a sparse embedding model, that resolves semantic meaning of the words, while keeping exact keyword match behavior. @@ -145,8 +145,8 @@ def __init__( if not self.lazy_load: self.load_onnx_model() - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: - return OnnxTextModel.tokenize(self, list(texts), **kwargs) + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + return self._tokenize(documents, **kwargs) def load_onnx_model(self) -> None: self._load_onnx_model( diff --git a/fastembed/sparse/sparse_embedding_base.py b/fastembed/sparse/sparse_embedding_base.py index 1cdcb2c2..86907ece 100644 --- a/fastembed/sparse/sparse_embedding_base.py +++ b/fastembed/sparse/sparse_embedding_base.py @@ -45,7 +45,7 @@ def __init__( self.threads = threads self._local_files_only = kwargs.pop("local_files_only", False) - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: raise NotImplementedError() def embed( diff --git a/fastembed/sparse/sparse_text_embedding.py b/fastembed/sparse/sparse_text_embedding.py index 16a1ca06..399e15b2 100644 --- a/fastembed/sparse/sparse_text_embedding.py +++ b/fastembed/sparse/sparse_text_embedding.py @@ -93,18 +93,18 @@ def __init__( "Please check the supported models using `SparseTextEmbedding.list_supported_models()`" ) - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: """ Tokenize input texts using the model's tokenizer. Args: - texts: List of strings to tokenize + documents: List of strings to tokenize **kwargs: Additional arguments passed to the tokenizer Returns: List of tokenizer Encodings """ - return self.model.tokenize(texts, **kwargs) + return self.model.tokenize(documents, **kwargs) def embed( self, diff --git a/fastembed/sparse/splade_pp.py b/fastembed/sparse/splade_pp.py index 240e96a6..a3fa2478 100644 --- a/fastembed/sparse/splade_pp.py +++ b/fastembed/sparse/splade_pp.py @@ -137,20 +137,18 @@ def load_onnx_model(self) -> None: device_id=self.device_id, ) - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: # type: ignore[override] + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: """ Tokenize input texts using the model's tokenizer. Args: - texts: List of strings to tokenize + documents: List of strings to tokenize **kwargs: Additional arguments passed to the tokenizer Returns: List of tokenizer Encodings """ - if self.tokenizer is None: - raise RuntimeError("Tokenizer not initialized") - return self.tokenizer.encode_batch(texts, **kwargs) + return self._tokenize(documents, **kwargs) def embed( self, diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index 24d142be..df864451 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -185,7 +185,7 @@ ] -class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[NumpyArray]): # type: ignore[misc] +class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[NumpyArray]): """Implementation of the Flag Embedding model.""" @classmethod @@ -321,17 +321,8 @@ def _post_process_onnx_output( raise ValueError(f"Unsupported embedding shape: {embeddings.shape}") return normalize(processed_embeddings) - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: - """Tokenize the input texts. - - Args: - texts: A list of strings to tokenize. - **kwargs: Additional keyword arguments. - - Returns: - list[Encoding]: List of tokenized encodings. - """ - return OnnxTextModel.tokenize(self, texts, **kwargs) + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + return self._tokenize(documents, **kwargs) def load_onnx_model(self) -> None: self._load_onnx_model( diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py index 5a6e0a6f..f71ebe2c 100644 --- a/fastembed/text/onnx_text_model.py +++ b/fastembed/text/onnx_text_model.py @@ -68,15 +68,15 @@ def _load_onnx_model( def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method") - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: - return self.tokenizer.encode_batch(texts) # type: ignore[union-attr] + def _tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + return self.tokenizer.encode_batch(documents) # type:ignore[union-attr] def onnx_embed( self, documents: list[str], **kwargs: Any, ) -> OnnxOutputContext: - encoded = self.tokenize(documents, **kwargs) + encoded = self._tokenize(documents, **kwargs) input_ids = np.array([e.ids for e in encoded]) attention_mask = np.array([e.attention_mask for e in encoded]) input_names = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index 2e56e4af..aee079de 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -163,18 +163,18 @@ def get_embedding_size(cls, model_name: str) -> int: ) return embedding_size - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: """ Tokenize input texts using the model's tokenizer. Args: - texts: List of strings to tokenize + documents: List of strings to tokenize **kwargs: Additional arguments passed to the tokenizer Returns: List of tokenizer Encodings """ - return self.model.tokenize(texts, **kwargs) + return self.model.tokenize(documents, **kwargs) def embed( self, diff --git a/fastembed/text/text_embedding_base.py b/fastembed/text/text_embedding_base.py index f9dabf90..ca7ef3a9 100644 --- a/fastembed/text/text_embedding_base.py +++ b/fastembed/text/text_embedding_base.py @@ -20,7 +20,7 @@ def __init__( self._local_files_only = kwargs.pop("local_files_only", False) self._embedding_size: Optional[int] = None - def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: raise NotImplementedError() def embed( diff --git a/tests/test_late_interaction_embeddings.py b/tests/test_late_interaction_embeddings.py index c037bef9..e3cbc6ca 100644 --- a/tests/test_late_interaction_embeddings.py +++ b/tests/test_late_interaction_embeddings.py @@ -249,23 +249,27 @@ def test_single_embedding_query(model_cache, model_name: str): @pytest.mark.parametrize("token_dim,model_name", [(96, "answerdotai/answerai-colbert-small-v1")]) -def test_parallel_processing(model_cache, token_dim: int, model_name: str): - with model_cache(model_name) as model: - docs = ["hello world", "flag embedding"] * 100 - embeddings = list(model.embed(docs, batch_size=10, parallel=2)) - embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) +def test_parallel_processing(token_dim: int, model_name: str): + # this test loads a copy of a model per process, might cause oom in parallel=0 on machines with + # an insufficient mem-to-cpus-ratio + is_ci = os.getenv("CI") + model = LateInteractionTextEmbedding(model_name=model_name) + docs = ["hello world", "flag embedding"] * 100 + embeddings = list(model.embed(docs, batch_size=10, parallel=2)) + + embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) - # embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) # inherits OnnxTextModel which - # # is tested in TextEmbedding, disabling it here to reduce number of requests to hf - # # multiprocessing is enough to test with `parallel=2`, and `parallel=None` is okay to tests since it reuses - # # model from cache + # embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) # inherits OnnxTextModel which + # # is tested in TextEmbedding, disabling it here to reduce number of requests to hf + # # multiprocessing is enough to test with `parallel=2`, and `parallel=None` is okay to tests since it reuses + # # model from cache - assert len(embeddings) == len(docs) and embeddings[0].shape[-1] == token_dim + assert len(embeddings) == len(docs) and embeddings[0].shape[-1] == token_dim - for i in range(len(embeddings)): - assert np.allclose(embeddings[i], embeddings_2[i], atol=1e-3) - # assert np.allclose(embeddings[i], embeddings_3[i], atol=1e-3) + for i in range(len(embeddings)): + assert np.allclose(embeddings[i], embeddings_2[i], atol=1e-3) + # assert np.allclose(embeddings[i], embeddings_3[i], atol=1e-3) @pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"]) @@ -325,7 +329,8 @@ def test_tokenize(model_name: str) -> None: enc_query = model.tokenize(["hello world"], is_doc=False) assert len(enc_query) == 1 assert enc_query[0].ids is not None - assert len(enc_query[0].ids) > 0 + assert len(enc_query[0].ids) == 31 # colbert requires query to be at least 32 tokens, + # padding is done during tokenization, the last token is added preprocess onnx input doc_ids = list(enc_doc[0].ids) query_ids = list(enc_query[0].ids) diff --git a/tests/test_sparse_embeddings.py b/tests/test_sparse_embeddings.py index f8eaed00..e942d07e 100644 --- a/tests/test_sparse_embeddings.py +++ b/tests/test_sparse_embeddings.py @@ -5,8 +5,6 @@ import numpy as np from fastembed.sparse.bm25 import Bm25 -from fastembed.sparse.bm42 import Bm42 -from fastembed.sparse.minicoil import MiniCOIL from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding from tests.utils import delete_model_cache, should_test_model @@ -280,9 +278,18 @@ def test_lazy_load(model_name: str) -> None: delete_model_cache(model.model._model_dir) -@pytest.mark.parametrize("model_name", ["prithivida/Splade_PP_en_v1"]) +@pytest.mark.parametrize( + "model_name", + [ + "prithivida/Splade_PP_en_v1", + "Qdrant/bm25", + "Qdrant/bm42-all-minilm-l6-v2-attentions", + "Qdrant/minicoil-v1", + ], +) def test_tokenize(model_name: str) -> None: is_ci = os.getenv("CI") + model = SparseTextEmbedding(model_name=model_name) encodings = model.tokenize(["hello world"]) @@ -299,63 +306,3 @@ def test_tokenize(model_name: str) -> None: if is_ci: delete_model_cache(model.model._model_dir) - - -def test_tokenize_bm25() -> None: - is_ci = os.getenv("CI") - model = Bm25("Qdrant/bm25", language="english") - - encodings = model.tokenize(["hello world"]) - assert len(encodings) == 1 - assert encodings[0].ids is not None - assert len(encodings[0].ids) > 0 - - texts = ["hello world", "flag embedding"] - encodings = model.tokenize(texts) - assert len(encodings) == 2 - for encoding in encodings: - assert encoding.ids is not None - assert len(encoding.ids) > 0 - - if is_ci: - delete_model_cache(model._model_dir) - - -def test_tokenize_bm42() -> None: - is_ci = os.getenv("CI") - model = Bm42("Qdrant/bm42-all-minilm-l6-v2-attentions") - - encodings = model.tokenize(["hello world"]) - assert len(encodings) == 1 - assert encodings[0].ids is not None - assert len(encodings[0].ids) > 0 - - texts = ["hello world", "flag embedding"] - encodings = model.tokenize(texts) - assert len(encodings) == 2 - for encoding in encodings: - assert encoding.ids is not None - assert len(encoding.ids) > 0 - - if is_ci: - delete_model_cache(model._model_dir) - - -def test_tokenize_minicoil() -> None: - is_ci = os.getenv("CI") - model = MiniCOIL("Qdrant/minicoil-v1") - - encodings = model.tokenize(["hello world"]) - assert len(encodings) == 1 - assert encodings[0].ids is not None - assert len(encodings[0].ids) > 0 - - texts = ["hello world", "flag embedding"] - encodings = model.tokenize(texts) - assert len(encodings) == 2 - for encoding in encodings: - assert encoding.ids is not None - assert len(encoding.ids) > 0 - - if is_ci: - delete_model_cache(model._model_dir) diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index 65f01af0..b958aefe 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -195,10 +195,9 @@ def test_embedding_size() -> None: delete_model_cache(model.model._model_dir) -@pytest.mark.parametrize("model_name", ["BAAI/bge-small-en-v1.5"]) -def test_tokenize(model_name: str) -> None: +def test_tokenize() -> None: is_ci = os.getenv("CI") - model = TextEmbedding(model_name=model_name) + model = TextEmbedding() encodings = model.tokenize(["hello world"]) assert len(encodings) == 1 From b9d80a8a473e06f361cd64889db7c8166f13219d Mon Sep 17 00:00:00 2001 From: Daniel Boros <56868953+dancixx@users.noreply.github.com> Date: Mon, 10 Nov 2025 09:55:15 +0100 Subject: [PATCH 12/15] chore: remove tokenize impl from sparse embedding (#567) * chore: remove tokenize impl from sparse embedding * chore: remove tokenize test * fix: return type hint --------- Co-authored-by: George Panchuk --- fastembed/sparse/bm25.py | 22 ++-------------- fastembed/sparse/bm42.py | 5 ++-- fastembed/sparse/minicoil.py | 6 ++--- fastembed/sparse/sparse_embedding_base.py | 5 ++-- fastembed/sparse/sparse_text_embedding.py | 21 ++++----------- fastembed/sparse/splade_pp.py | 15 ++--------- tests/test_sparse_embeddings.py | 32 +---------------------- 7 files changed, 17 insertions(+), 89 deletions(-) diff --git a/fastembed/sparse/bm25.py b/fastembed/sparse/bm25.py index 4ec4ec99..80ea441b 100644 --- a/fastembed/sparse/bm25.py +++ b/fastembed/sparse/bm25.py @@ -7,7 +7,6 @@ import mmh3 import numpy as np from py_rust_stemmers import SnowballStemmer -from tokenizers import Encoding from fastembed.common.utils import ( define_cache_dir, iter_batch, @@ -137,25 +136,8 @@ def __init__( self.tokenizer = SimpleTokenizer - def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: - """Tokenize texts using SimpleTokenizer. - - Returns a list of simple Encoding-like objects with token strings. - Note: BM25 uses a simple word tokenizer, not a learned tokenizer. - """ - result = [] - - class SimpleEncoding: - def __init__(self, tokens: list[str]): - self.tokens = tokens - self.ids = tokens # For BM25, tokens are the IDs - self.attention_mask = [1] * len(tokens) - - for document in documents: - tokens = self.tokenizer.tokenize(document) - result.append(SimpleEncoding(tokens)) - - return result + def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") @classmethod def _list_supported_models(cls) -> list[SparseModelDescription]: diff --git a/fastembed/sparse/bm42.py b/fastembed/sparse/bm42.py index 937747a8..8025e671 100644 --- a/fastembed/sparse/bm42.py +++ b/fastembed/sparse/bm42.py @@ -16,7 +16,6 @@ ) from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker from fastembed.common.model_description import SparseModelDescription, ModelSource -from tokenizers import Encoding supported_bm42_models: list[SparseModelDescription] = [ SparseModelDescription( @@ -139,8 +138,8 @@ def __init__( if not self.lazy_load: self.load_onnx_model() - def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: - return self._tokenize(documents, **kwargs) + def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") def load_onnx_model(self) -> None: self._load_onnx_model( diff --git a/fastembed/sparse/minicoil.py b/fastembed/sparse/minicoil.py index 7490f588..9334a57b 100644 --- a/fastembed/sparse/minicoil.py +++ b/fastembed/sparse/minicoil.py @@ -5,7 +5,7 @@ import numpy as np from numpy.typing import NDArray from py_rust_stemmers import SnowballStemmer -from tokenizers import Tokenizer, Encoding +from tokenizers import Tokenizer from fastembed.common.model_description import SparseModelDescription, ModelSource from fastembed.common.onnx_model import OnnxOutputContext @@ -145,8 +145,8 @@ def __init__( if not self.lazy_load: self.load_onnx_model() - def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: - return self._tokenize(documents, **kwargs) + def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") def load_onnx_model(self) -> None: self._load_onnx_model( diff --git a/fastembed/sparse/sparse_embedding_base.py b/fastembed/sparse/sparse_embedding_base.py index 86907ece..c33291fc 100644 --- a/fastembed/sparse/sparse_embedding_base.py +++ b/fastembed/sparse/sparse_embedding_base.py @@ -3,7 +3,6 @@ import numpy as np from numpy.typing import NDArray -from tokenizers import Encoding from fastembed.common.model_description import SparseModelDescription from fastembed.common.types import NumpyArray @@ -45,8 +44,8 @@ def __init__( self.threads = threads self._local_files_only = kwargs.pop("local_files_only", False) - def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: - raise NotImplementedError() + def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") def embed( self, diff --git a/fastembed/sparse/sparse_text_embedding.py b/fastembed/sparse/sparse_text_embedding.py index 399e15b2..04fa6dab 100644 --- a/fastembed/sparse/sparse_text_embedding.py +++ b/fastembed/sparse/sparse_text_embedding.py @@ -1,9 +1,10 @@ -from typing import Any, Iterable, Optional, Sequence, Type, Union +import warnings from dataclasses import asdict +from typing import Any, Iterable, Optional, Sequence, Type, Union -from tokenizers import Encoding from fastembed.common import OnnxProvider +from fastembed.common.model_description import SparseModelDescription from fastembed.sparse.bm25 import Bm25 from fastembed.sparse.bm42 import Bm42 from fastembed.sparse.minicoil import MiniCOIL @@ -12,8 +13,6 @@ SparseTextEmbeddingBase, ) from fastembed.sparse.splade_pp import SpladePP -import warnings -from fastembed.common.model_description import SparseModelDescription class SparseTextEmbedding(SparseTextEmbeddingBase): @@ -93,18 +92,8 @@ def __init__( "Please check the supported models using `SparseTextEmbedding.list_supported_models()`" ) - def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: - """ - Tokenize input texts using the model's tokenizer. - - Args: - documents: List of strings to tokenize - **kwargs: Additional arguments passed to the tokenizer - - Returns: - List of tokenizer Encodings - """ - return self.model.tokenize(documents, **kwargs) + def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") def embed( self, diff --git a/fastembed/sparse/splade_pp.py b/fastembed/sparse/splade_pp.py index a3fa2478..70f66bc9 100644 --- a/fastembed/sparse/splade_pp.py +++ b/fastembed/sparse/splade_pp.py @@ -1,7 +1,6 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union import numpy as np -from tokenizers import Encoding from fastembed.common import OnnxProvider from fastembed.common.onnx_model import OnnxOutputContext @@ -137,18 +136,8 @@ def load_onnx_model(self) -> None: device_id=self.device_id, ) - def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: - """ - Tokenize input texts using the model's tokenizer. - - Args: - documents: List of strings to tokenize - **kwargs: Additional arguments passed to the tokenizer - - Returns: - List of tokenizer Encodings - """ - return self._tokenize(documents, **kwargs) + def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") def embed( self, diff --git a/tests/test_sparse_embeddings.py b/tests/test_sparse_embeddings.py index e942d07e..0e68f6c1 100644 --- a/tests/test_sparse_embeddings.py +++ b/tests/test_sparse_embeddings.py @@ -1,8 +1,8 @@ import os from contextlib import contextmanager -import pytest import numpy as np +import pytest from fastembed.sparse.bm25 import Bm25 from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding @@ -276,33 +276,3 @@ def test_lazy_load(model_name: str) -> None: if is_ci: delete_model_cache(model.model._model_dir) - - -@pytest.mark.parametrize( - "model_name", - [ - "prithivida/Splade_PP_en_v1", - "Qdrant/bm25", - "Qdrant/bm42-all-minilm-l6-v2-attentions", - "Qdrant/minicoil-v1", - ], -) -def test_tokenize(model_name: str) -> None: - is_ci = os.getenv("CI") - - model = SparseTextEmbedding(model_name=model_name) - - encodings = model.tokenize(["hello world"]) - assert len(encodings) == 1 - assert encodings[0].ids is not None - assert len(encodings[0].ids) > 0 - - texts = ["hello world", "flag embedding"] - encodings = model.tokenize(texts) - assert len(encodings) == 2 - for encoding in encodings: - assert encoding.ids is not None - assert len(encoding.ids) > 0 - - if is_ci: - delete_model_cache(model.model._model_dir) From 683eba2ad9dd3520d9a454b23be344fc942bff6b Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Mon, 10 Nov 2025 16:40:25 +0700 Subject: [PATCH 13/15] chore: exception message --- fastembed/late_interaction/late_interaction_embedding_base.py | 2 +- .../late_interaction_multimodal_embedding_base.py | 2 +- fastembed/text/text_embedding_base.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fastembed/late_interaction/late_interaction_embedding_base.py b/fastembed/late_interaction/late_interaction_embedding_base.py index e5ae230a..1d9ba2dd 100644 --- a/fastembed/late_interaction/late_interaction_embedding_base.py +++ b/fastembed/late_interaction/late_interaction_embedding_base.py @@ -22,7 +22,7 @@ def __init__( self._embedding_size: Optional[int] = None def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: - raise NotImplementedError() + raise NotImplementedError("Subclasses must implement this method.") def embed( self, diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py index bd5d04be..bb3d7402 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py @@ -23,7 +23,7 @@ def __init__( self._embedding_size: Optional[int] = None def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: - raise NotImplementedError() + raise NotImplementedError("Subclasses must implement this method.") def embed_text( self, diff --git a/fastembed/text/text_embedding_base.py b/fastembed/text/text_embedding_base.py index ca7ef3a9..2100b549 100644 --- a/fastembed/text/text_embedding_base.py +++ b/fastembed/text/text_embedding_base.py @@ -21,7 +21,7 @@ def __init__( self._embedding_size: Optional[int] = None def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: - raise NotImplementedError() + raise NotImplementedError("Subclasses must implement this method.") def embed( self, From a26aad61d59651aa081e26734c6bd76ec60b8f0b Mon Sep 17 00:00:00 2001 From: Daniel Boros Date: Tue, 18 Nov 2025 14:32:02 +0100 Subject: [PATCH 14/15] feat: add public token_count across embedding models; return non-padded token counts; raise NotImplemented for sparse --- fastembed/late_interaction/colbert.py | 23 +++++++++++++++++++ .../late_interaction_embedding_base.py | 3 +++ .../late_interaction_text_embedding.py | 3 +++ .../late_interaction_multimodal/colpali.py | 10 ++++++++ .../late_interaction_multimodal_embedding.py | 3 +++ ...e_interaction_multimodal_embedding_base.py | 3 +++ fastembed/sparse/sparse_embedding_base.py | 3 +++ fastembed/sparse/sparse_text_embedding.py | 3 +++ fastembed/text/onnx_embedding.py | 10 ++++++++ fastembed/text/text_embedding.py | 3 +++ fastembed/text/text_embedding_base.py | 3 +++ 11 files changed, 67 insertions(+) diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index fa742882..ccebf958 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -83,6 +83,29 @@ def _preprocess_onnx_input( def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: return self._tokenize(documents, **kwargs) + def token_count(self, documents: list[str], is_doc: bool = True, **kwargs: Any) -> list[int]: + if not hasattr(self, "model") or self.model is None: + self.load_onnx_model() + counts: list[int] = [] + if is_doc: + encoded = self._tokenize_documents(documents=documents) + assert self.pad_token_id is not None + for e in encoded: + counts.append( + sum( + 1 + for tid in e.ids + if tid not in self.skip_list and tid != self.pad_token_id + ) + ) + else: + # query padding uses MASK token; exclude it from count + assert self.mask_token_id is not None + encoded = self._tokenize_query(query=next(iter(documents))) + for e in encoded: + counts.append(sum(1 for tid in e.ids if tid != self.mask_token_id)) + return counts + def _tokenize( self, documents: list[str], is_doc: bool = True, **kwargs: Any ) -> list[Encoding]: diff --git a/fastembed/late_interaction/late_interaction_embedding_base.py b/fastembed/late_interaction/late_interaction_embedding_base.py index 1d9ba2dd..5aaaa363 100644 --- a/fastembed/late_interaction/late_interaction_embedding_base.py +++ b/fastembed/late_interaction/late_interaction_embedding_base.py @@ -24,6 +24,9 @@ def __init__( def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: raise NotImplementedError("Subclasses must implement this method.") + def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: + raise NotImplementedError("Subclasses must implement this method.") + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/late_interaction/late_interaction_text_embedding.py b/fastembed/late_interaction/late_interaction_text_embedding.py index 3d3b41d1..78976e2a 100644 --- a/fastembed/late_interaction/late_interaction_text_embedding.py +++ b/fastembed/late_interaction/late_interaction_text_embedding.py @@ -129,6 +129,9 @@ def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: """ return self.model.tokenize(documents, **kwargs) + def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: + return self.model.token_count(documents, **kwargs) + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index 7b466b24..04743844 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -163,6 +163,16 @@ def _post_process_onnx_text_output( def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: return self._tokenize(documents, **kwargs) + def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: + encoded = self._tokenize(documents, **kwargs) + counts: list[int] = [] + for e in encoded: + try: + counts.append(int(sum(e.attention_mask))) # type: ignore[arg-type] + except Exception: + counts.append(len(e.ids)) + return counts + def _tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: texts_query: list[str] = [] for query in documents: diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py index e7df631d..e256466b 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py @@ -132,6 +132,9 @@ def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: """ return self.model.tokenize(documents, **kwargs) + def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: + return self.model.token_count(documents, **kwargs) + def embed_text( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py index bb3d7402..eb3e3eb0 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py @@ -25,6 +25,9 @@ def __init__( def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: raise NotImplementedError("Subclasses must implement this method.") + def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: + raise NotImplementedError("Subclasses must implement this method.") + def embed_text( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/sparse/sparse_embedding_base.py b/fastembed/sparse/sparse_embedding_base.py index c33291fc..ba9b25f5 100644 --- a/fastembed/sparse/sparse_embedding_base.py +++ b/fastembed/sparse/sparse_embedding_base.py @@ -47,6 +47,9 @@ def __init__( def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") + def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: + raise NotImplementedError("Token count for sparse embeddings is not implemented.") + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/sparse/sparse_text_embedding.py b/fastembed/sparse/sparse_text_embedding.py index 04fa6dab..e2cab29d 100644 --- a/fastembed/sparse/sparse_text_embedding.py +++ b/fastembed/sparse/sparse_text_embedding.py @@ -95,6 +95,9 @@ def __init__( def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") + def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: + raise NotImplementedError("Token count for sparse embeddings is not implemented.") + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index df864451..432d09d0 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -324,6 +324,16 @@ def _post_process_onnx_output( def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: return self._tokenize(documents, **kwargs) + def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: + encoded = self._tokenize(documents, **kwargs) + counts: list[int] = [] + for e in encoded: + try: + counts.append(int(sum(e.attention_mask))) # type: ignore[arg-type] + except Exception: + counts.append(len(e.ids)) + return counts + def load_onnx_model(self) -> None: self._load_onnx_model( model_dir=self._model_dir, diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index aee079de..a220b148 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -176,6 +176,9 @@ def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: """ return self.model.tokenize(documents, **kwargs) + def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: + return self.model.token_count(documents, **kwargs) + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/text/text_embedding_base.py b/fastembed/text/text_embedding_base.py index 2100b549..e1ffcd8d 100644 --- a/fastembed/text/text_embedding_base.py +++ b/fastembed/text/text_embedding_base.py @@ -23,6 +23,9 @@ def __init__( def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: raise NotImplementedError("Subclasses must implement this method.") + def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: + raise NotImplementedError("Subclasses must implement this method.") + def embed( self, documents: Union[str, Iterable[str]], From 42974dd4f885be38a6588e2b346d520de7e87cfa Mon Sep 17 00:00:00 2001 From: Daniel Boros Date: Tue, 18 Nov 2025 14:35:08 +0100 Subject: [PATCH 15/15] Revert "feat: add public token_count across embedding models; return non-padded token counts; raise NotImplemented for sparse" This reverts commit a26aad61d59651aa081e26734c6bd76ec60b8f0b. --- fastembed/late_interaction/colbert.py | 23 ------------------- .../late_interaction_embedding_base.py | 3 --- .../late_interaction_text_embedding.py | 3 --- .../late_interaction_multimodal/colpali.py | 10 -------- .../late_interaction_multimodal_embedding.py | 3 --- ...e_interaction_multimodal_embedding_base.py | 3 --- fastembed/sparse/sparse_embedding_base.py | 3 --- fastembed/sparse/sparse_text_embedding.py | 3 --- fastembed/text/onnx_embedding.py | 10 -------- fastembed/text/text_embedding.py | 3 --- fastembed/text/text_embedding_base.py | 3 --- 11 files changed, 67 deletions(-) diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index ccebf958..fa742882 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -83,29 +83,6 @@ def _preprocess_onnx_input( def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: return self._tokenize(documents, **kwargs) - def token_count(self, documents: list[str], is_doc: bool = True, **kwargs: Any) -> list[int]: - if not hasattr(self, "model") or self.model is None: - self.load_onnx_model() - counts: list[int] = [] - if is_doc: - encoded = self._tokenize_documents(documents=documents) - assert self.pad_token_id is not None - for e in encoded: - counts.append( - sum( - 1 - for tid in e.ids - if tid not in self.skip_list and tid != self.pad_token_id - ) - ) - else: - # query padding uses MASK token; exclude it from count - assert self.mask_token_id is not None - encoded = self._tokenize_query(query=next(iter(documents))) - for e in encoded: - counts.append(sum(1 for tid in e.ids if tid != self.mask_token_id)) - return counts - def _tokenize( self, documents: list[str], is_doc: bool = True, **kwargs: Any ) -> list[Encoding]: diff --git a/fastembed/late_interaction/late_interaction_embedding_base.py b/fastembed/late_interaction/late_interaction_embedding_base.py index 5aaaa363..1d9ba2dd 100644 --- a/fastembed/late_interaction/late_interaction_embedding_base.py +++ b/fastembed/late_interaction/late_interaction_embedding_base.py @@ -24,9 +24,6 @@ def __init__( def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: raise NotImplementedError("Subclasses must implement this method.") - def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: - raise NotImplementedError("Subclasses must implement this method.") - def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/late_interaction/late_interaction_text_embedding.py b/fastembed/late_interaction/late_interaction_text_embedding.py index 78976e2a..3d3b41d1 100644 --- a/fastembed/late_interaction/late_interaction_text_embedding.py +++ b/fastembed/late_interaction/late_interaction_text_embedding.py @@ -129,9 +129,6 @@ def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: """ return self.model.tokenize(documents, **kwargs) - def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: - return self.model.token_count(documents, **kwargs) - def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index 04743844..7b466b24 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -163,16 +163,6 @@ def _post_process_onnx_text_output( def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: return self._tokenize(documents, **kwargs) - def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: - encoded = self._tokenize(documents, **kwargs) - counts: list[int] = [] - for e in encoded: - try: - counts.append(int(sum(e.attention_mask))) # type: ignore[arg-type] - except Exception: - counts.append(len(e.ids)) - return counts - def _tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: texts_query: list[str] = [] for query in documents: diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py index e256466b..e7df631d 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py @@ -132,9 +132,6 @@ def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: """ return self.model.tokenize(documents, **kwargs) - def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: - return self.model.token_count(documents, **kwargs) - def embed_text( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py index eb3e3eb0..bb3d7402 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py @@ -25,9 +25,6 @@ def __init__( def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: raise NotImplementedError("Subclasses must implement this method.") - def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: - raise NotImplementedError("Subclasses must implement this method.") - def embed_text( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/sparse/sparse_embedding_base.py b/fastembed/sparse/sparse_embedding_base.py index ba9b25f5..c33291fc 100644 --- a/fastembed/sparse/sparse_embedding_base.py +++ b/fastembed/sparse/sparse_embedding_base.py @@ -47,9 +47,6 @@ def __init__( def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") - def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: - raise NotImplementedError("Token count for sparse embeddings is not implemented.") - def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/sparse/sparse_text_embedding.py b/fastembed/sparse/sparse_text_embedding.py index e2cab29d..04fa6dab 100644 --- a/fastembed/sparse/sparse_text_embedding.py +++ b/fastembed/sparse/sparse_text_embedding.py @@ -95,9 +95,6 @@ def __init__( def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") - def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: - raise NotImplementedError("Token count for sparse embeddings is not implemented.") - def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index 432d09d0..df864451 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -324,16 +324,6 @@ def _post_process_onnx_output( def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: return self._tokenize(documents, **kwargs) - def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: - encoded = self._tokenize(documents, **kwargs) - counts: list[int] = [] - for e in encoded: - try: - counts.append(int(sum(e.attention_mask))) # type: ignore[arg-type] - except Exception: - counts.append(len(e.ids)) - return counts - def load_onnx_model(self) -> None: self._load_onnx_model( model_dir=self._model_dir, diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index a220b148..aee079de 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -176,9 +176,6 @@ def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: """ return self.model.tokenize(documents, **kwargs) - def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: - return self.model.token_count(documents, **kwargs) - def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/text/text_embedding_base.py b/fastembed/text/text_embedding_base.py index e1ffcd8d..2100b549 100644 --- a/fastembed/text/text_embedding_base.py +++ b/fastembed/text/text_embedding_base.py @@ -23,9 +23,6 @@ def __init__( def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: raise NotImplementedError("Subclasses must implement this method.") - def token_count(self, documents: list[str], **kwargs: Any) -> list[int]: - raise NotImplementedError("Subclasses must implement this method.") - def embed( self, documents: Union[str, Iterable[str]],