diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 3aa105ae..fa742882 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -80,7 +80,12 @@ def _preprocess_onnx_input( ) return onnx_input - def tokenize(self, documents: list[str], is_doc: bool = True, **kwargs: Any) -> list[Encoding]: + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + return self._tokenize(documents, **kwargs) + + def _tokenize( + self, documents: list[str], is_doc: bool = True, **kwargs: Any + ) -> list[Encoding]: return ( self._tokenize_documents(documents=documents) if is_doc diff --git a/fastembed/late_interaction/late_interaction_embedding_base.py b/fastembed/late_interaction/late_interaction_embedding_base.py index f677ba98..1d9ba2dd 100644 --- a/fastembed/late_interaction/late_interaction_embedding_base.py +++ b/fastembed/late_interaction/late_interaction_embedding_base.py @@ -1,5 +1,7 @@ from typing import Iterable, Optional, Union, Any +from tokenizers import Encoding + from fastembed.common.model_description import DenseModelDescription from fastembed.common.types import NumpyArray from fastembed.common.model_management import ModelManagement @@ -19,6 +21,9 @@ def __init__( self._local_files_only = kwargs.pop("local_files_only", False) self._embedding_size: Optional[int] = None + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + raise NotImplementedError("Subclasses must implement this method.") + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/late_interaction/late_interaction_text_embedding.py b/fastembed/late_interaction/late_interaction_text_embedding.py index 22833618..3d3b41d1 100644 --- a/fastembed/late_interaction/late_interaction_text_embedding.py +++ b/fastembed/late_interaction/late_interaction_text_embedding.py @@ -1,6 +1,8 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union from dataclasses import asdict +from tokenizers import Encoding + from fastembed.common.model_description import DenseModelDescription from fastembed.common.types import NumpyArray from fastembed.common import OnnxProvider @@ -114,6 +116,19 @@ def get_embedding_size(cls, model_name: str) -> int: ) return embedding_size + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + """ + Tokenize input texts using the model's tokenizer. + + Args: + documents: List of strings to tokenize + **kwargs: Additional arguments passed to the tokenizer + + Returns: + List of tokenizer Encodings + """ + return self.model.tokenize(documents, **kwargs) + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index 9fab9535..7b466b24 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -161,13 +161,17 @@ def _post_process_onnx_text_output( return output.model_output def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + return self._tokenize(documents, **kwargs) + + def _tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: texts_query: list[str] = [] for query in documents: query = self.BOS_TOKEN + self.QUERY_PREFIX + query + self.PAD_TOKEN * 10 query += "\n" texts_query.append(query) - encoded = self.tokenizer.encode_batch(texts_query) # type: ignore[union-attr] + assert self.tokenizer is not None + encoded = self.tokenizer.encode_batch(texts_query) return encoded def _preprocess_onnx_text_input( diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py index 39c1763e..e7df631d 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py @@ -1,6 +1,8 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union from dataclasses import asdict +from tokenizers import Encoding + from fastembed.common import OnnxProvider, ImageInput from fastembed.common.types import NumpyArray from fastembed.late_interaction_multimodal.colpali import ColPali @@ -117,6 +119,19 @@ def get_embedding_size(cls, model_name: str) -> int: ) return embedding_size + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + """ + Tokenize input texts using the model's tokenizer. + + Args: + documents: List of strings to tokenize + **kwargs: Additional arguments passed to the tokenizer + + Returns: + List of tokenizer Encodings + """ + return self.model.tokenize(documents, **kwargs) + def embed_text( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py index 12e3553c..bb3d7402 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py @@ -1,5 +1,6 @@ from typing import Iterable, Optional, Union, Any +from tokenizers import Encoding from fastembed.common import ImageInput from fastembed.common.model_description import DenseModelDescription @@ -21,6 +22,9 @@ def __init__( self._local_files_only = kwargs.pop("local_files_only", False) self._embedding_size: Optional[int] = None + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + raise NotImplementedError("Subclasses must implement this method.") + def embed_text( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py index 83706a2b..c2f22b9f 100644 --- a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py +++ b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py @@ -80,15 +80,17 @@ def _load_onnx_model( def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method") - def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: - return self.tokenizer.encode_batch(documents) # type: ignore[union-attr] + def _tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + if self.tokenizer is None: + raise RuntimeError("Tokenizer not initialized") + return self.tokenizer.encode_batch(documents, **kwargs) # type: ignore[union-attr] def onnx_embed_text( self, documents: list[str], **kwargs: Any, ) -> OnnxOutputContext: - encoded = self.tokenize(documents, **kwargs) + encoded = self._tokenize(documents, **kwargs) input_ids = np.array([e.ids for e in encoded]) attention_mask = np.array([e.attention_mask for e in encoded]) # type: ignore[union-attr] input_names = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] diff --git a/fastembed/rerank/cross_encoder/onnx_text_model.py b/fastembed/rerank/cross_encoder/onnx_text_model.py index 3fc4e81c..48ee5792 100644 --- a/fastembed/rerank/cross_encoder/onnx_text_model.py +++ b/fastembed/rerank/cross_encoder/onnx_text_model.py @@ -45,8 +45,8 @@ def _load_onnx_model( self.tokenizer, _ = load_tokenizer(model_dir=model_dir) assert self.tokenizer is not None - def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]: - return self.tokenizer.encode_batch(pairs) # type: ignore[union-attr] + def tokenize(self, pairs: list[tuple[str, str]], **kwargs: Any) -> list[Encoding]: + return self.tokenizer.encode_batch(pairs, **kwargs) # type: ignore[union-attr] def _build_onnx_input(self, tokenized_input: list[Encoding]) -> dict[str, NumpyArray]: input_names: set[str] = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] diff --git a/fastembed/sparse/bm25.py b/fastembed/sparse/bm25.py index b6ac59fd..80ea441b 100644 --- a/fastembed/sparse/bm25.py +++ b/fastembed/sparse/bm25.py @@ -136,6 +136,9 @@ def __init__( self.tokenizer = SimpleTokenizer + def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") + @classmethod def _list_supported_models(cls) -> list[SparseModelDescription]: """Lists the supported models. diff --git a/fastembed/sparse/bm42.py b/fastembed/sparse/bm42.py index 3e51404f..8025e671 100644 --- a/fastembed/sparse/bm42.py +++ b/fastembed/sparse/bm42.py @@ -138,6 +138,9 @@ def __init__( if not self.lazy_load: self.load_onnx_model() + def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") + def load_onnx_model(self) -> None: self._load_onnx_model( model_dir=self._model_dir, diff --git a/fastembed/sparse/minicoil.py b/fastembed/sparse/minicoil.py index efaa9abb..9334a57b 100644 --- a/fastembed/sparse/minicoil.py +++ b/fastembed/sparse/minicoil.py @@ -145,6 +145,9 @@ def __init__( if not self.lazy_load: self.load_onnx_model() + def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") + def load_onnx_model(self) -> None: self._load_onnx_model( model_dir=self._model_dir, diff --git a/fastembed/sparse/sparse_embedding_base.py b/fastembed/sparse/sparse_embedding_base.py index b153c814..c33291fc 100644 --- a/fastembed/sparse/sparse_embedding_base.py +++ b/fastembed/sparse/sparse_embedding_base.py @@ -44,6 +44,9 @@ def __init__( self.threads = threads self._local_files_only = kwargs.pop("local_files_only", False) + def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/sparse/sparse_text_embedding.py b/fastembed/sparse/sparse_text_embedding.py index 3cb14c3e..04fa6dab 100644 --- a/fastembed/sparse/sparse_text_embedding.py +++ b/fastembed/sparse/sparse_text_embedding.py @@ -1,7 +1,10 @@ -from typing import Any, Iterable, Optional, Sequence, Type, Union +import warnings from dataclasses import asdict +from typing import Any, Iterable, Optional, Sequence, Type, Union + from fastembed.common import OnnxProvider +from fastembed.common.model_description import SparseModelDescription from fastembed.sparse.bm25 import Bm25 from fastembed.sparse.bm42 import Bm42 from fastembed.sparse.minicoil import MiniCOIL @@ -10,8 +13,6 @@ SparseTextEmbeddingBase, ) from fastembed.sparse.splade_pp import SpladePP -import warnings -from fastembed.common.model_description import SparseModelDescription class SparseTextEmbedding(SparseTextEmbeddingBase): @@ -91,6 +92,9 @@ def __init__( "Please check the supported models using `SparseTextEmbedding.list_supported_models()`" ) + def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/sparse/splade_pp.py b/fastembed/sparse/splade_pp.py index d2c4af38..70f66bc9 100644 --- a/fastembed/sparse/splade_pp.py +++ b/fastembed/sparse/splade_pp.py @@ -1,6 +1,7 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union import numpy as np + from fastembed.common import OnnxProvider from fastembed.common.onnx_model import OnnxOutputContext from fastembed.common.utils import define_cache_dir @@ -135,6 +136,9 @@ def load_onnx_model(self) -> None: device_id=self.device_id, ) + def tokenize(self, documents: list[str], **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("Tokenize method for sparse embeddings is not implemented yet.") + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index 4cc892f5..df864451 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -1,5 +1,7 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union +from tokenizers import Encoding + from fastembed.common.types import NumpyArray, OnnxProvider from fastembed.common.onnx_model import OnnxOutputContext from fastembed.common.utils import define_cache_dir, normalize @@ -319,6 +321,9 @@ def _post_process_onnx_output( raise ValueError(f"Unsupported embedding shape: {embeddings.shape}") return normalize(processed_embeddings) + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + return self._tokenize(documents, **kwargs) + def load_onnx_model(self) -> None: self._load_onnx_model( model_dir=self._model_dir, diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py index c939b21d..f71ebe2c 100644 --- a/fastembed/text/onnx_text_model.py +++ b/fastembed/text/onnx_text_model.py @@ -68,15 +68,15 @@ def _load_onnx_model( def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method") - def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: - return self.tokenizer.encode_batch(documents) # type: ignore[union-attr] + def _tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + return self.tokenizer.encode_batch(documents) # type:ignore[union-attr] def onnx_embed( self, documents: list[str], **kwargs: Any, ) -> OnnxOutputContext: - encoded = self.tokenize(documents, **kwargs) + encoded = self._tokenize(documents, **kwargs) input_ids = np.array([e.ids for e in encoded]) attention_mask = np.array([e.attention_mask for e in encoded]) input_names = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index 117f5af7..aee079de 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -1,5 +1,6 @@ import warnings from typing import Any, Iterable, Optional, Sequence, Type, Union +from tokenizers import Encoding from dataclasses import asdict from fastembed.common.types import NumpyArray, OnnxProvider @@ -162,6 +163,19 @@ def get_embedding_size(cls, model_name: str) -> int: ) return embedding_size + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + """ + Tokenize input texts using the model's tokenizer. + + Args: + documents: List of strings to tokenize + **kwargs: Additional arguments passed to the tokenizer + + Returns: + List of tokenizer Encodings + """ + return self.model.tokenize(documents, **kwargs) + def embed( self, documents: Union[str, Iterable[str]], diff --git a/fastembed/text/text_embedding_base.py b/fastembed/text/text_embedding_base.py index 75df9ac5..2100b549 100644 --- a/fastembed/text/text_embedding_base.py +++ b/fastembed/text/text_embedding_base.py @@ -1,4 +1,5 @@ from typing import Iterable, Optional, Union, Any +from tokenizers import Encoding from fastembed.common.model_description import DenseModelDescription from fastembed.common.types import NumpyArray @@ -19,6 +20,9 @@ def __init__( self._local_files_only = kwargs.pop("local_files_only", False) self._embedding_size: Optional[int] = None + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + raise NotImplementedError("Subclasses must implement this method.") + def embed( self, documents: Union[str, Iterable[str]], diff --git a/tests/test_late_interaction_embeddings.py b/tests/test_late_interaction_embeddings.py index f89882f4..e3cbc6ca 100644 --- a/tests/test_late_interaction_embeddings.py +++ b/tests/test_late_interaction_embeddings.py @@ -249,23 +249,27 @@ def test_single_embedding_query(model_cache, model_name: str): @pytest.mark.parametrize("token_dim,model_name", [(96, "answerdotai/answerai-colbert-small-v1")]) -def test_parallel_processing(model_cache, token_dim: int, model_name: str): - with model_cache(model_name) as model: - docs = ["hello world", "flag embedding"] * 100 - embeddings = list(model.embed(docs, batch_size=10, parallel=2)) - embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) +def test_parallel_processing(token_dim: int, model_name: str): + # this test loads a copy of a model per process, might cause oom in parallel=0 on machines with + # an insufficient mem-to-cpus-ratio + is_ci = os.getenv("CI") + model = LateInteractionTextEmbedding(model_name=model_name) + docs = ["hello world", "flag embedding"] * 100 + embeddings = list(model.embed(docs, batch_size=10, parallel=2)) + + embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) - # embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) # inherits OnnxTextModel which - # # is tested in TextEmbedding, disabling it here to reduce number of requests to hf - # # multiprocessing is enough to test with `parallel=2`, and `parallel=None` is okay to tests since it reuses - # # model from cache + # embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) # inherits OnnxTextModel which + # # is tested in TextEmbedding, disabling it here to reduce number of requests to hf + # # multiprocessing is enough to test with `parallel=2`, and `parallel=None` is okay to tests since it reuses + # # model from cache - assert len(embeddings) == len(docs) and embeddings[0].shape[-1] == token_dim + assert len(embeddings) == len(docs) and embeddings[0].shape[-1] == token_dim - for i in range(len(embeddings)): - assert np.allclose(embeddings[i], embeddings_2[i], atol=1e-3) - # assert np.allclose(embeddings[i], embeddings_3[i], atol=1e-3) + for i in range(len(embeddings)): + assert np.allclose(embeddings[i], embeddings_2[i], atol=1e-3) + # assert np.allclose(embeddings[i], embeddings_3[i], atol=1e-3) @pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"]) @@ -308,3 +312,29 @@ def test_embedding_size(): assert model.embedding_size == 96 if is_ci: delete_model_cache(model.model._model_dir) + + +@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"]) +def test_tokenize(model_name: str) -> None: + is_ci = os.getenv("CI") + model = LateInteractionTextEmbedding(model_name=model_name) + + texts = ["hello world", "flag embedding"] + enc_doc = model.tokenize(texts, is_doc=True) + assert len(enc_doc) == 2 + for encoding in enc_doc: + assert encoding.ids is not None + assert len(encoding.ids) > 0 + + enc_query = model.tokenize(["hello world"], is_doc=False) + assert len(enc_query) == 1 + assert enc_query[0].ids is not None + assert len(enc_query[0].ids) == 31 # colbert requires query to be at least 32 tokens, + # padding is done during tokenization, the last token is added preprocess onnx input + + doc_ids = list(enc_doc[0].ids) + query_ids = list(enc_query[0].ids) + assert doc_ids != query_ids + + if is_ci: + delete_model_cache(model.model._model_dir) diff --git a/tests/test_late_interaction_multimodal.py b/tests/test_late_interaction_multimodal.py index 80135f3b..2a1bf631 100644 --- a/tests/test_late_interaction_multimodal.py +++ b/tests/test_late_interaction_multimodal.py @@ -101,3 +101,23 @@ def test_embedding_size(): model_name = "Qdrant/ColPali-v1.3-fp16" model = LateInteractionMultimodalEmbedding(model_name=model_name, lazy_load=True) assert model.embedding_size == 128 + + +@pytest.mark.parametrize("model_name", ["Qdrant/colpali-v1.3-fp16"]) +def test_tokenize(model_name: str) -> None: + if os.getenv("CI"): + pytest.skip("Colpali is too large to test in CI") + + model = LateInteractionMultimodalEmbedding(model_name=model_name) + + encodings = model.tokenize(["hello world"]) + assert len(encodings) == 1 + assert encodings[0].ids is not None + assert len(encodings[0].ids) > 0 + + texts = ["hello world", "flag embedding"] + encodings = model.tokenize(texts) + assert len(encodings) == 2 + for encoding in encodings: + assert encoding.ids is not None + assert len(encoding.ids) > 0 diff --git a/tests/test_sparse_embeddings.py b/tests/test_sparse_embeddings.py index 90514967..0e68f6c1 100644 --- a/tests/test_sparse_embeddings.py +++ b/tests/test_sparse_embeddings.py @@ -1,8 +1,8 @@ import os from contextlib import contextmanager -import pytest import numpy as np +import pytest from fastembed.sparse.bm25 import Bm25 from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index 46ce6554..b958aefe 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -193,3 +193,23 @@ def test_embedding_size() -> None: if is_ci: delete_model_cache(model.model._model_dir) + + +def test_tokenize() -> None: + is_ci = os.getenv("CI") + model = TextEmbedding() + + encodings = model.tokenize(["hello world"]) + assert len(encodings) == 1 + assert encodings[0].ids is not None + assert len(encodings[0].ids) > 0 + + texts = ["hello world", "flag embedding"] + encodings = model.tokenize(texts) + assert len(encodings) == 2 + for encoding in encodings: + assert encoding.ids is not None + assert len(encoding.ids) > 0 + + if is_ci: + delete_model_cache(model.model._model_dir)