diff --git a/fastembed/common/model_description.py b/fastembed/common/model_description.py
new file mode 100644
index 00000000..43e42f5f
--- /dev/null
+++ b/fastembed/common/model_description.py
@@ -0,0 +1,40 @@
+from dataclasses import dataclass, field
+from typing import Optional, Any
+
+
+@dataclass(frozen=True)
+class ModelSource:
+    hf: Optional[str] = None
+    url: Optional[str] = None
+
+    def __post_init__(self) -> None:
+        if self.hf is None and self.url is None:
+            raise ValueError(
+                f"At least one source should be set, current sources: hf={self.hf}, url={self.url}"
+            )
+
+
+@dataclass(frozen=True)
+class BaseModelDescription:
+    model: str
+    sources: ModelSource
+    model_file: str
+    description: str
+    license: str
+    size_in_GB: float
+    additional_files: list[str] = field(default_factory=list)
+
+
+@dataclass(frozen=True)
+class DenseModelDescription(BaseModelDescription):
+    dim: Optional[int] = None
+    tasks: Optional[dict[str, Any]] = None
+
+    def __post_init__(self) -> None:
+        assert self.dim is not None, "dim is required for dense model description"
+
+
+@dataclass(frozen=True)
+class SparseModelDescription(BaseModelDescription):
+    requires_idf: Optional[bool] = None
+    vocab_size: Optional[int] = None
diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py
index f05eaede..c16c4670 100644
--- a/fastembed/common/model_management.py
+++ b/fastembed/common/model_management.py
@@ -4,7 +4,7 @@
 import shutil
 import tarfile
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, TypeVar, Generic
 
 import requests
 from huggingface_hub import snapshot_download, model_info, list_repo_tree
@@ -16,9 +16,12 @@
 )
 from loguru import logger
 from tqdm import tqdm
+from fastembed.common.model_description import BaseModelDescription
 
+T = TypeVar("T", bound=BaseModelDescription)
 
-class ModelManagement:
+
+class ModelManagement(Generic[T]):
     METADATA_FILE = "files_metadata.json"
 
     @classmethod
@@ -26,12 +29,16 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
         """Lists the supported models.
 
         Returns:
             list[dict[str, Any]]: A list of dictionaries containing the model information.
         """
         raise NotImplementedError()
 
     @classmethod
-    def _get_model_description(cls, model_name: str) -> dict[str, Any]:
+    def _list_supported_models(cls) -> list[T]:
+        raise NotImplementedError()
+
+    @classmethod
+    def _get_model_description(cls, model_name: str) -> T:
         """
         Gets the model description from the model_name.
 
         Args:
             model_name (str): The name of the model.
 
         Raises:
             ValueError: If the model_name is not supported.
 
         Returns:
-            dict[str, Any]: The model description.
+            T: The model description.
""" - for model in cls.list_supported_models(): - if model_name.lower() == model["model"].lower(): + for model in cls._list_supported_models(): + if model_name.lower() == model.model.lower(): return model raise ValueError(f"Model {model_name} is not supported in {cls.__name__}.") @@ -160,7 +167,9 @@ def _collect_file_metadata( } return meta - def _save_file_metadata(model_dir: Path, meta: dict[str, dict[str, Union[int, str]]]) -> None: + def _save_file_metadata( + model_dir: Path, meta: dict[str, dict[str, Union[int, str]]] + ) -> None: try: if not model_dir.exists(): model_dir.mkdir(parents=True, exist_ok=True) @@ -292,7 +301,11 @@ def decompress_to_cache(cls, targz_path: str, cache_dir: str) -> str: @classmethod def retrieve_model_gcs( - cls, model_name: str, source_url: str, cache_dir: str, local_files_only: bool = False + cls, + model_name: str, + source_url: str, + cache_dir: str, + local_files_only: bool = False, ) -> Path: fast_model_name = f"fast-{model_name.split('/')[-1]}" cache_tmp_dir = Path(cache_dir) / "tmp" @@ -336,14 +349,12 @@ def retrieve_model_gcs( return model_dir @classmethod - def download_model( - cls, model: dict[str, Any], cache_dir: str, retries: int = 3, **kwargs: Any - ) -> Path: + def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: Any) -> Path: """ Downloads a model from HuggingFace Hub or Google Cloud Storage. Args: - model (dict[str, Any]): The model description. + model (T): The model description. Example: ``` { @@ -368,16 +379,16 @@ def download_model( if specific_model_path: return Path(specific_model_path) retries = 1 if local_files_only else retries - hf_source = model.get("sources", {}).get("hf") - url_source = model.get("sources", {}).get("url") + hf_source = model.sources.hf + url_source = model.sources.url sleep = 3.0 while retries > 0: retries -= 1 if hf_source: - extra_patterns = [model["model_file"]] - extra_patterns.extend(model.get("additional_files", [])) + extra_patterns = [model.model_file] + extra_patterns.extend(model.additional_files) try: return Path( @@ -399,8 +410,8 @@ def download_model( if url_source or local_files_only: try: return cls.retrieve_model_gcs( - model["model"], - url_source, + model.model, + str(url_source), str(cache_dir), local_files_only=local_files_only, ) @@ -417,4 +428,4 @@ def download_model( time.sleep(sleep) sleep *= 3 - raise ValueError(f"Could not load model {model['model']} from any source.") + raise ValueError(f"Could not load model {model.model} from any source.") diff --git a/fastembed/image/image_embedding.py b/fastembed/image/image_embedding.py index 7c4140de..e5d9f17e 100644 --- a/fastembed/image/image_embedding.py +++ b/fastembed/image/image_embedding.py @@ -1,9 +1,11 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union +from dataclasses import asdict from fastembed.common.types import NumpyArray from fastembed.common import ImageInput, OnnxProvider from fastembed.image.image_embedding_base import ImageEmbeddingBase from fastembed.image.onnx_embedding import OnnxImageEmbedding +from fastembed.common.model_description import DenseModelDescription class ImageEmbedding(ImageEmbeddingBase): @@ -34,9 +36,13 @@ def list_supported_models(cls) -> list[dict[str, Any]]: ] ``` """ - result: list[dict[str, Any]] = [] + return [asdict(model) for model in cls._list_supported_models()] + + @classmethod + def _list_supported_models(cls) -> list[DenseModelDescription]: + result: list[DenseModelDescription] = [] for embedding in cls.EMBEDDINGS_REGISTRY: - 
result.extend(embedding.list_supported_models()) + result.extend(embedding._list_supported_models()) return result def __init__( @@ -52,8 +58,8 @@ def __init__( ): super().__init__(model_name, cache_dir, threads, **kwargs) for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: - supported_models = EMBEDDING_MODEL_TYPE.list_supported_models() - if any(model_name.lower() == model["model"].lower() for model in supported_models): + supported_models = EMBEDDING_MODEL_TYPE._list_supported_models() + if any(model_name.lower() == model.model.lower() for model in supported_models): self.model = EMBEDDING_MODEL_TYPE( model_name, cache_dir, diff --git a/fastembed/image/image_embedding_base.py b/fastembed/image/image_embedding_base.py index 17328632..35485cc0 100644 --- a/fastembed/image/image_embedding_base.py +++ b/fastembed/image/image_embedding_base.py @@ -1,11 +1,12 @@ from typing import Iterable, Optional, Any, Union +from fastembed.common.model_description import DenseModelDescription from fastembed.common.types import NumpyArray from fastembed.common.model_management import ModelManagement from fastembed.common.types import ImageInput -class ImageEmbeddingBase(ModelManagement): +class ImageEmbeddingBase(ModelManagement[DenseModelDescription]): def __init__( self, model_name: str, diff --git a/fastembed/image/onnx_embedding.py b/fastembed/image/onnx_embedding.py index 08b6e1c2..9db22fa0 100644 --- a/fastembed/image/onnx_embedding.py +++ b/fastembed/image/onnx_embedding.py @@ -9,62 +9,54 @@ from fastembed.image.image_embedding_base import ImageEmbeddingBase from fastembed.image.onnx_image_model import ImageEmbeddingWorker, OnnxImageModel -supported_onnx_models = [ - { - "model": "Qdrant/clip-ViT-B-32-vision", - "dim": 512, - "description": "Image embeddings, Multimodal (text&image), 2021 year", - "license": "mit", - "size_in_GB": 0.34, - "sources": { - "hf": "Qdrant/clip-ViT-B-32-vision", - }, - "model_file": "model.onnx", - }, - { - "model": "Qdrant/resnet50-onnx", - "dim": 2048, - "description": "Image embeddings, Unimodal (image), 2016 year", - "license": "apache-2.0", - "size_in_GB": 0.1, - "sources": { - "hf": "Qdrant/resnet50-onnx", - }, - "model_file": "model.onnx", - }, - { - "model": "Qdrant/Unicom-ViT-B-16", - "dim": 768, - "description": "Image embeddings (more detailed than Unicom-ViT-B-32), Multimodal (text&image), 2023 year", - "license": "apache-2.0", - "size_in_GB": 0.82, - "sources": { - "hf": "Qdrant/Unicom-ViT-B-16", - }, - "model_file": "model.onnx", - }, - { - "model": "Qdrant/Unicom-ViT-B-32", - "dim": 512, - "description": "Image embeddings, Multimodal (text&image), 2023 year", - "license": "apache-2.0", - "size_in_GB": 0.48, - "sources": { - "hf": "Qdrant/Unicom-ViT-B-32", - }, - "model_file": "model.onnx", - }, - { - "model": "jinaai/jina-clip-v1", - "dim": 768, - "description": "Image embeddings, Multimodal (text&image), 2024 year", - "license": "apache-2.0", - "size_in_GB": 0.34, - "sources": { - "hf": "jinaai/jina-clip-v1", - }, - "model_file": "onnx/vision_model.onnx", - }, +from fastembed.common.model_description import DenseModelDescription, ModelSource + +supported_onnx_models: list[DenseModelDescription] = [ + DenseModelDescription( + model="Qdrant/clip-ViT-B-32-vision", + dim=512, + description="Image embeddings, Multimodal (text&image), 2021 year", + license="mit", + size_in_GB=0.34, + sources=ModelSource(hf="Qdrant/clip-ViT-B-32-vision"), + model_file="model.onnx", + ), + DenseModelDescription( + model="Qdrant/resnet50-onnx", + dim=2048, + description="Image 
embeddings, Unimodal (image), 2016 year", + license="apache-2.0", + size_in_GB=0.1, + sources=ModelSource(hf="Qdrant/resnet50-onnx"), + model_file="model.onnx", + ), + DenseModelDescription( + model="Qdrant/Unicom-ViT-B-16", + dim=768, + description="Image embeddings (more detailed than Unicom-ViT-B-32), Multimodal (text&image), 2023 year", + license="apache-2.0", + size_in_GB=0.82, + sources=ModelSource(hf="Qdrant/Unicom-ViT-B-16"), + model_file="model.onnx", + ), + DenseModelDescription( + model="Qdrant/Unicom-ViT-B-32", + dim=512, + description="Image embeddings, Multimodal (text&image), 2023 year", + license="apache-2.0", + size_in_GB=0.48, + sources=ModelSource(hf="Qdrant/Unicom-ViT-B-32"), + model_file="model.onnx", + ), + DenseModelDescription( + model="jinaai/jina-clip-v1", + dim=768, + description="Image embeddings, Multimodal (text&image), 2024 year", + license="apache-2.0", + size_in_GB=0.34, + sources=ModelSource(hf="jinaai/jina-clip-v1"), + model_file="onnx/vision_model.onnx", + ), ] @@ -137,7 +129,7 @@ def load_onnx_model(self) -> None: """ self._load_onnx_model( model_dir=self._model_dir, - model_file=self.model_description["model_file"], + model_file=self.model_description.model_file, threads=self.threads, providers=self.providers, cuda=self.cuda, @@ -145,12 +137,12 @@ def load_onnx_model(self) -> None: ) @classmethod - def list_supported_models(cls) -> list[dict[str, Any]]: + def _list_supported_models(cls) -> list[DenseModelDescription]: """ Lists the supported models. Returns: - list[dict[str, Any]]: A list of dictionaries containing the model information. + list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information. """ return supported_onnx_models diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index c544774a..841d3a73 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -12,31 +12,27 @@ LateInteractionTextEmbeddingBase, ) from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker - - -supported_colbert_models = [ - { - "model": "colbert-ir/colbertv2.0", - "dim": 128, - "description": "Late interaction model", - "license": "mit", - "size_in_GB": 0.44, - "sources": { - "hf": "colbert-ir/colbertv2.0", - }, - "model_file": "model.onnx", - }, - { - "model": "answerdotai/answerai-colbert-small-v1", - "dim": 96, - "description": "Text embeddings, Unimodal (text), Multilingual (~100 languages), 512 input tokens truncation, 2024 year", - "license": "apache-2.0", - "size_in_GB": 0.13, - "sources": { - "hf": "answerdotai/answerai-colbert-small-v1", - }, - "model_file": "vespa_colbert.onnx", - }, +from fastembed.common.model_description import DenseModelDescription, ModelSource + +supported_colbert_models: list[DenseModelDescription] = [ + DenseModelDescription( + model="colbert-ir/colbertv2.0", + dim=128, + description="Late interaction model", + license="mit", + size_in_GB=0.44, + sources=ModelSource(hf="colbert-ir/colbertv2.0"), + model_file="model.onnx", + ), + DenseModelDescription( + model="answerdotai/answerai-colbert-small-v1", + dim=96, + description="Text embeddings, Unimodal (text), Multilingual (~100 languages), 512 input tokens truncation, 2024 year", + license="apache-2.0", + size_in_GB=0.13, + sources=ModelSource(hf="answerdotai/answerai-colbert-small-v1"), + model_file="vespa_colbert.onnx", + ), ] @@ -112,11 +108,11 @@ def _tokenize_documents(self, documents: list[str]) -> list[Encoding]: return encoded @classmethod 
-    def list_supported_models(cls) -> list[dict[str, Any]]:
+    def _list_supported_models(cls) -> list[DenseModelDescription]:
         """Lists the supported models.
 
         Returns:
-            list[dict[str, Any]]: A list of dictionaries containing the model information.
+            list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
         """
         return supported_colbert_models
@@ -189,7 +185,7 @@ def __init__(
     def load_onnx_model(self) -> None:
         self._load_onnx_model(
             model_dir=self._model_dir,
-            model_file=self.model_description["model_file"],
+            model_file=self.model_description.model_file,
             threads=self.threads,
             providers=self.providers,
             cuda=self.cuda,
diff --git a/fastembed/late_interaction/jina_colbert.py b/fastembed/late_interaction/jina_colbert.py
index 402d660d..3ef89c63 100644
--- a/fastembed/late_interaction/jina_colbert.py
+++ b/fastembed/late_interaction/jina_colbert.py
@@ -2,21 +2,19 @@
 
 from fastembed.common.types import NumpyArray
 from fastembed.late_interaction.colbert import Colbert, ColbertEmbeddingWorker
-
-
-supported_jina_colbert_models = [
-    {
-        "model": "jinaai/jina-colbert-v2",
-        "dim": 128,
-        "description": "New model that expands capabilities of colbert-v1 with multilingual and context length of 8192, 2024 year",
-        "license": "cc-by-nc-4.0",
-        "size_in_GB": 2.24,
-        "sources": {
-            "hf": "jinaai/jina-colbert-v2",
-        },
-        "model_file": "onnx/model.onnx",
-        "additional_files": ["onnx/model.onnx_data"],
-    },
+from fastembed.common.model_description import DenseModelDescription, ModelSource
+
+supported_jina_colbert_models: list[DenseModelDescription] = [
+    DenseModelDescription(
+        model="jinaai/jina-colbert-v2",
+        dim=128,
+        description="New model that expands capabilities of colbert-v1 with multilingual and context length of 8192, 2024 year",
+        license="cc-by-nc-4.0",
+        size_in_GB=2.24,
+        sources=ModelSource(hf="jinaai/jina-colbert-v2"),
+        model_file="onnx/model.onnx",
+        additional_files=["onnx/model.onnx_data"],
+    )
 ]
@@ -31,11 +29,11 @@ def _get_worker_class(cls) -> Type[ColbertEmbeddingWorker]:
         return JinaColbertEmbeddingWorker
 
     @classmethod
-    def list_supported_models(cls) -> list[dict[str, Any]]:
+    def _list_supported_models(cls) -> list[DenseModelDescription]:
         """Lists the supported models.
 
         Returns:
-            list[dict[str, Any]]: A list of dictionaries containing the model information.
+            list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
""" return supported_jina_colbert_models diff --git a/fastembed/late_interaction/late_interaction_embedding_base.py b/fastembed/late_interaction/late_interaction_embedding_base.py index 20dcaa72..af635f17 100644 --- a/fastembed/late_interaction/late_interaction_embedding_base.py +++ b/fastembed/late_interaction/late_interaction_embedding_base.py @@ -1,10 +1,11 @@ from typing import Iterable, Optional, Union, Any +from fastembed.common.model_description import DenseModelDescription from fastembed.common.types import NumpyArray from fastembed.common.model_management import ModelManagement -class LateInteractionTextEmbeddingBase(ModelManagement): +class LateInteractionTextEmbeddingBase(ModelManagement[DenseModelDescription]): def __init__( self, model_name: str, diff --git a/fastembed/late_interaction/late_interaction_text_embedding.py b/fastembed/late_interaction/late_interaction_text_embedding.py index 4405b1d5..8f6aa835 100644 --- a/fastembed/late_interaction/late_interaction_text_embedding.py +++ b/fastembed/late_interaction/late_interaction_text_embedding.py @@ -1,5 +1,7 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union +from dataclasses import asdict +from fastembed.common.model_description import DenseModelDescription from fastembed.common.types import NumpyArray from fastembed.common import OnnxProvider from fastembed.late_interaction.colbert import Colbert @@ -37,9 +39,13 @@ def list_supported_models(cls) -> list[dict[str, Any]]: ] ``` """ - result: list[dict[str, Any]] = [] + return [asdict(model) for model in cls._list_supported_models()] + + @classmethod + def _list_supported_models(cls) -> list[DenseModelDescription]: + result: list[DenseModelDescription] = [] for embedding in cls.EMBEDDINGS_REGISTRY: - result.extend(embedding.list_supported_models()) + result.extend(embedding._list_supported_models()) return result def __init__( @@ -55,8 +61,8 @@ def __init__( ): super().__init__(model_name, cache_dir, threads, **kwargs) for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: - supported_models = EMBEDDING_MODEL_TYPE.list_supported_models() - if any(model_name.lower() == model["model"].lower() for model in supported_models): + supported_models = EMBEDDING_MODEL_TYPE._list_supported_models() + if any(model_name.lower() == model.model.lower() for model in supported_models): self.model = EMBEDDING_MODEL_TYPE( model_name, cache_dir, diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index 053ecacb..731c902b 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -15,22 +15,19 @@ TextEmbeddingWorker, ImageEmbeddingWorker, ) - -supported_colpali_models = [ - { - "model": "Qdrant/colpali-v1.3-fp16", - "dim": 128, - "description": "Text embeddings, Multimodal (text&image), English, 50 tokens query length truncation, 2024.", - "license": "mit", - "size_in_GB": 6.5, - "sources": { - "hf": "Qdrant/colpali-v1.3-fp16", - }, - "additional_files": [ - "model.onnx_data", - ], - "model_file": "model.onnx", - }, +from fastembed.common.model_description import DenseModelDescription, ModelSource + +supported_colpali_models: list[DenseModelDescription] = [ + DenseModelDescription( + model="Qdrant/colpali-v1.3-fp16", + dim=128, + description="Text embeddings, Multimodal (text&image), English, 50 tokens query length truncation, 2024.", + license="mit", + size_in_GB=6.5, + sources=ModelSource(hf="Qdrant/colpali-v1.3-fp16"), + 
additional_files=["model.onnx_data"], + model_file="model.onnx", + ), ] @@ -111,18 +108,18 @@ def __init__( self.load_onnx_model() @classmethod - def list_supported_models(cls) -> list[dict[str, Any]]: + def _list_supported_models(cls) -> list[DenseModelDescription]: """Lists the supported models. Returns: - list[dict[str, Any]]: A list of dictionaries containing the model information. + list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information. """ return supported_colpali_models def load_onnx_model(self) -> None: self._load_onnx_model( model_dir=self._model_dir, - model_file=self.model_description["model_file"], + model_file=self.model_description.model_file, threads=self.threads, providers=self.providers, cuda=self.cuda, @@ -142,8 +139,9 @@ def _post_process_onnx_image_output( Returns: Iterable[NumpyArray]: Post-processed output as NumPy arrays. """ + assert self.model_description.dim is not None, "Model dim is not defined" return output.model_output.reshape( - output.model_output.shape[0], -1, self.model_description["dim"] + output.model_output.shape[0], -1, self.model_description.dim ).astype(np.float32) def _post_process_onnx_text_output( diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py index 08819a53..e7c0beb5 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py @@ -1,4 +1,5 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union +from dataclasses import asdict from fastembed.common import OnnxProvider, ImageInput from fastembed.common.types import NumpyArray @@ -7,6 +8,7 @@ from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import ( LateInteractionMultimodalEmbeddingBase, ) +from fastembed.common.model_description import DenseModelDescription class LateInteractionMultimodalEmbedding(LateInteractionMultimodalEmbeddingBase): @@ -24,25 +26,29 @@ def list_supported_models(cls) -> list[dict[str, Any]]: ``` [ { - "model": "AndrewOgn/colpali-v1.3-merged-onnx", + "model": "Qdrant/colpali-v1.3-fp16", "dim": 128, "description": "Text embeddings, Unimodal (text), Aligned to image latent space, ColBERT-compatible, 512 tokens max, 2024.", "license": "mit", "size_in_GB": 6.06, "sources": { - "hf": "AndrewOgn/colpali-v1.3-merged-onnx", - }, + "hf": "Qdrant/colpali-v1.3-fp16", + }, "additional_files": [ - "model.onnx_data", - ], - "model_file": "model.onnx", + "model.onnx_data", + ], + "model_file": "model.onnx", }, ] ``` """ - result: list[dict[str, Any]] = [] + return [asdict(model) for model in cls._list_supported_models()] + + @classmethod + def _list_supported_models(cls) -> list[DenseModelDescription]: + result: list[DenseModelDescription] = [] for embedding in cls.EMBEDDINGS_REGISTRY: - result.extend(embedding.list_supported_models()) + result.extend(embedding._list_supported_models()) return result def __init__( @@ -58,8 +64,8 @@ def __init__( ): super().__init__(model_name, cache_dir, threads, **kwargs) for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: - supported_models = EMBEDDING_MODEL_TYPE.list_supported_models() - if any(model_name.lower() == model["model"].lower() for model in supported_models): + supported_models = EMBEDDING_MODEL_TYPE._list_supported_models() + if any(model_name.lower() == model.model.lower() for model in supported_models): 
                 self.model = EMBEDDING_MODEL_TYPE(
                     model_name,
                     cache_dir,
diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py
index 64ee8643..193ff7c5 100644
--- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py
+++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py
@@ -2,11 +2,12 @@
 
 from fastembed.common import ImageInput
+from fastembed.common.model_description import DenseModelDescription
 from fastembed.common.model_management import ModelManagement
 from fastembed.common.types import NumpyArray
 
 
-class LateInteractionMultimodalEmbeddingBase(ModelManagement):
+class LateInteractionMultimodalEmbeddingBase(ModelManagement[DenseModelDescription]):
     def __init__(
         self,
         model_name: str,
diff --git a/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py b/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
index 461b5837..af4b4e69 100644
--- a/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
+++ b/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
@@ -10,78 +10,67 @@
     TextRerankerWorker,
 )
 from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
-
-supported_onnx_models = [
-    {
-        "model": "Xenova/ms-marco-MiniLM-L-6-v2",
-        "size_in_GB": 0.08,
-        "sources": {
-            "hf": "Xenova/ms-marco-MiniLM-L-6-v2",
-        },
-        "model_file": "onnx/model.onnx",
-        "description": "MiniLM-L-6-v2 model optimized for re-ranking tasks.",
-        "license": "apache-2.0",
-    },
-    {
-        "model": "Xenova/ms-marco-MiniLM-L-12-v2",
-        "size_in_GB": 0.12,
-        "sources": {
-            "hf": "Xenova/ms-marco-MiniLM-L-12-v2",
-        },
-        "model_file": "onnx/model.onnx",
-        "description": "MiniLM-L-12-v2 model optimized for re-ranking tasks.",
-        "license": "apache-2.0",
-    },
-    {
-        "model": "BAAI/bge-reranker-base",
-        "size_in_GB": 1.04,
-        "sources": {
-            "hf": "BAAI/bge-reranker-base",
-        },
-        "model_file": "onnx/model.onnx",
-        "description": "BGE reranker base model for cross-encoder re-ranking.",
-        "license": "mit",
-    },
-    {
-        "model": "jinaai/jina-reranker-v1-tiny-en",
-        "size_in_GB": 0.13,
-        "sources": {
-            "hf": "jinaai/jina-reranker-v1-tiny-en",
-        },
-        "model_file": "onnx/model.onnx",
-        "description": "Designed for blazing-fast re-ranking with 8K context length and fewer parameters than jina-reranker-v1-turbo-en.",
-        "license": "apache-2.0",
-    },
-    {
-        "model": "jinaai/jina-reranker-v1-turbo-en",
-        "size_in_GB": 0.15,
-        "sources": {
-            "hf": "jinaai/jina-reranker-v1-turbo-en",
-        },
-        "model_file": "onnx/model.onnx",
-        "description": "Designed for blazing-fast re-ranking with 8K context length.",
-        "license": "apache-2.0",
-    },
-    {
-        "model": "jinaai/jina-reranker-v2-base-multilingual",
-        "size_in_GB": 1.11,
-        "sources": {
-            "hf": "jinaai/jina-reranker-v2-base-multilingual",
-        },
-        "model_file": "onnx/model.onnx",
-        "description": "A multi-lingual reranker model for cross-encoder re-ranking with 1K context length and sliding window",
-        "license": "cc-by-nc-4.0",
-    },
+from fastembed.common.model_description import BaseModelDescription, ModelSource
+
+supported_onnx_models: list[BaseModelDescription] = [
+    BaseModelDescription(
+        model="Xenova/ms-marco-MiniLM-L-6-v2",
+        description="MiniLM-L-6-v2 model optimized for re-ranking tasks.",
+        license="apache-2.0",
+        size_in_GB=0.08,
+        sources=ModelSource(hf="Xenova/ms-marco-MiniLM-L-6-v2"),
+        model_file="onnx/model.onnx",
+    ),
+    BaseModelDescription(
+        model="Xenova/ms-marco-MiniLM-L-12-v2",
+        description="MiniLM-L-12-v2 model optimized for re-ranking tasks.",
+        license="apache-2.0",
+        size_in_GB=0.12,
+        sources=ModelSource(hf="Xenova/ms-marco-MiniLM-L-12-v2"),
+        model_file="onnx/model.onnx",
+    ),
+    BaseModelDescription(
+        model="BAAI/bge-reranker-base",
+        description="BGE reranker base model for cross-encoder re-ranking.",
+        license="mit",
+        size_in_GB=1.04,
+        sources=ModelSource(hf="BAAI/bge-reranker-base"),
+        model_file="onnx/model.onnx",
+    ),
+    BaseModelDescription(
+        model="jinaai/jina-reranker-v1-tiny-en",
+        description="Designed for blazing-fast re-ranking with 8K context length and fewer parameters than jina-reranker-v1-turbo-en.",
+        license="apache-2.0",
+        size_in_GB=0.13,
+        sources=ModelSource(hf="jinaai/jina-reranker-v1-tiny-en"),
+        model_file="onnx/model.onnx",
+    ),
+    BaseModelDescription(
+        model="jinaai/jina-reranker-v1-turbo-en",
+        description="Designed for blazing-fast re-ranking with 8K context length.",
+        license="apache-2.0",
+        size_in_GB=0.15,
+        sources=ModelSource(hf="jinaai/jina-reranker-v1-turbo-en"),
+        model_file="onnx/model.onnx",
+    ),
+    BaseModelDescription(
+        model="jinaai/jina-reranker-v2-base-multilingual",
+        description="A multi-lingual reranker model for cross-encoder re-ranking with 1K context length and sliding window",
+        license="cc-by-nc-4.0",
+        size_in_GB=1.11,
+        sources=ModelSource(hf="jinaai/jina-reranker-v2-base-multilingual"),
+        model_file="onnx/model.onnx",
+    ),
 ]
 
 
 class OnnxTextCrossEncoder(TextCrossEncoderBase, OnnxCrossEncoderModel):
     @classmethod
-    def list_supported_models(cls) -> list[dict[str, Any]]:
+    def _list_supported_models(cls) -> list[BaseModelDescription]:
         """Lists the supported models.
 
         Returns:
-            list[dict[str, Any]]: A list of dictionaries containing the model information.
+            list[BaseModelDescription]: A list of BaseModelDescription objects containing the model information.
         """
         return supported_onnx_models
@@ -155,7 +144,7 @@ def __init__(
     def load_onnx_model(self) -> None:
         self._load_onnx_model(
             model_dir=self._model_dir,
-            model_file=self.model_description["model_file"],
+            model_file=self.model_description.model_file,
             threads=self.threads,
             providers=self.providers,
             cuda=self.cuda,
diff --git a/fastembed/rerank/cross_encoder/text_cross_encoder.py b/fastembed/rerank/cross_encoder/text_cross_encoder.py
index 614ab69d..573053e0 100644
--- a/fastembed/rerank/cross_encoder/text_cross_encoder.py
+++ b/fastembed/rerank/cross_encoder/text_cross_encoder.py
@@ -1,8 +1,10 @@
 from typing import Any, Iterable, Optional, Sequence, Type
+from dataclasses import asdict
 
 from fastembed.common import OnnxProvider
 from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder
 from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
+from fastembed.common.model_description import BaseModelDescription
 
 
 class TextCrossEncoder(TextCrossEncoderBase):
@@ -15,7 +17,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
         """Lists the supported models.
 
         Returns:
             list[dict[str, Any]]: A list of dictionaries containing the model information.
 
         Example:
             ```
@@ -33,9 +35,13 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
             ]
             ```
         """
-        result: list[dict[str, Any]] = []
+        return [asdict(model) for model in cls._list_supported_models()]
+
+    @classmethod
+    def _list_supported_models(cls) -> list[BaseModelDescription]:
+        result: list[BaseModelDescription] = []
         for encoder in cls.CROSS_ENCODER_REGISTRY:
-            result.extend(encoder.list_supported_models())
+            result.extend(encoder._list_supported_models())
         return result
 
     def __init__(
@@ -52,8 +58,8 @@ def __init__(
 
         super().__init__(model_name, cache_dir, threads, **kwargs)
         for CROSS_ENCODER_TYPE in self.CROSS_ENCODER_REGISTRY:
-            supported_models = CROSS_ENCODER_TYPE.list_supported_models()
-            if any(model_name.lower() == model["model"].lower() for model in supported_models):
+            supported_models = CROSS_ENCODER_TYPE._list_supported_models()
+            if any(model_name.lower() == model.model.lower() for model in supported_models):
                 self.model = CROSS_ENCODER_TYPE(
                     model_name=model_name,
                     cache_dir=cache_dir,
diff --git a/fastembed/rerank/cross_encoder/text_cross_encoder_base.py b/fastembed/rerank/cross_encoder/text_cross_encoder_base.py
index cadf6d0b..84b44e41 100644
--- a/fastembed/rerank/cross_encoder/text_cross_encoder_base.py
+++ b/fastembed/rerank/cross_encoder/text_cross_encoder_base.py
@@ -1,9 +1,10 @@
 from typing import Any, Iterable, Optional
 
+from fastembed.common.model_description import BaseModelDescription
 from fastembed.common.model_management import ModelManagement
 
 
-class TextCrossEncoderBase(ModelManagement):
+class TextCrossEncoderBase(ModelManagement[BaseModelDescription]):
     def __init__(
         self,
         model_name: str,
diff --git a/fastembed/sparse/bm25.py b/fastembed/sparse/bm25.py
index 1a21f4ee..bd2b43ee 100644
--- a/fastembed/sparse/bm25.py
+++ b/fastembed/sparse/bm25.py
@@ -19,6 +19,7 @@
     SparseTextEmbeddingBase,
 )
 from fastembed.sparse.utils.tokenizer import SimpleTokenizer
+from fastembed.common.model_description import SparseModelDescription, ModelSource
 
 supported_languages = [
     "arabic",
@@ -52,19 +53,18 @@
     "turkish",
 ]
 
-supported_bm25_models = [
-    {
-        "model": "Qdrant/bm25",
-        "description": "BM25 as sparse embeddings meant to be used with Qdrant",
-        "license": "apache-2.0",
-        "size_in_GB": 0.01,
-        "sources": {
-            "hf": "Qdrant/bm25",
-        },
-        "model_file": "mock.file",  # bm25 does not require a model, so we just use a mock
-        "additional_files": [f"{lang}.txt" for lang in supported_languages],
-        "requires_idf": True,
-    },
+supported_bm25_models: list[SparseModelDescription] = [
+    SparseModelDescription(
+        model="Qdrant/bm25",
+        vocab_size=0,
+        description="BM25 as sparse embeddings meant to be used with Qdrant",
+        license="apache-2.0",
+        size_in_GB=0.01,
+        sources=ModelSource(hf="Qdrant/bm25"),
+        additional_files=[f"{lang}.txt" for lang in supported_languages],
+        requires_idf=True,
+        model_file="mock.file",  # bm25 does not require a model, so we just use a mock
+    ),
 ]
@@ -146,11 +146,11 @@ def __init__(
         self.tokenizer = SimpleTokenizer
 
     @classmethod
-    def list_supported_models(cls) -> list[dict[str, Any]]:
+    def _list_supported_models(cls) -> list[SparseModelDescription]:
         """Lists the supported models.
 
         Returns:
-            list[dict[str, Any]]: A list of dictionaries containing the model information.
+            list[SparseModelDescription]: A list of SparseModelDescription objects containing the model information.
""" return supported_bm25_models diff --git a/fastembed/sparse/bm42.py b/fastembed/sparse/bm42.py index f34abb29..05ff2df7 100644 --- a/fastembed/sparse/bm42.py +++ b/fastembed/sparse/bm42.py @@ -15,21 +15,20 @@ SparseTextEmbeddingBase, ) from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker - -supported_bm42_models = [ - { - "model": "Qdrant/bm42-all-minilm-l6-v2-attentions", - "vocab_size": 30522, - "description": "Light sparse embedding model, which assigns an importance score to each token in the text", - "license": "apache-2.0", - "size_in_GB": 0.09, - "sources": { - "hf": "Qdrant/all_miniLM_L6_v2_with_attentions", - }, - "model_file": "model.onnx", - "additional_files": ["stopwords.txt"], - "requires_idf": True, - }, +from fastembed.common.model_description import SparseModelDescription, ModelSource + +supported_bm42_models: list[SparseModelDescription] = [ + SparseModelDescription( + model="Qdrant/bm42-all-minilm-l6-v2-attentions", + vocab_size=30522, + description="Light sparse embedding model, which assigns an importance score to each token in the text", + license="apache-2.0", + size_in_GB=0.09, + sources=ModelSource(hf="Qdrant/all_miniLM_L6_v2_with_attentions"), + model_file="model.onnx", + additional_files=["stopwords.txt"], + requires_idf=True, + ), ] MODEL_TO_LANGUAGE = { @@ -133,7 +132,7 @@ def __init__( def load_onnx_model(self) -> None: self._load_onnx_model( model_dir=self._model_dir, - model_file=self.model_description["model_file"], + model_file=self.model_description.model_file, threads=self.threads, providers=self.providers, cuda=self.cuda, @@ -251,11 +250,11 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Spars yield SparseEmbedding.from_dict(rescored) @classmethod - def list_supported_models(cls) -> list[dict[str, Any]]: + def _list_supported_models(cls) -> list[SparseModelDescription]: """Lists the supported models. Returns: - list[dict[str, Any]]: A list of dictionaries containing the model information. + list[SparseModelDescription]: A list of SparseModelDescription objects containing the model information. 
""" return supported_bm42_models diff --git a/fastembed/sparse/sparse_embedding_base.py b/fastembed/sparse/sparse_embedding_base.py index c6dc3393..b153c814 100644 --- a/fastembed/sparse/sparse_embedding_base.py +++ b/fastembed/sparse/sparse_embedding_base.py @@ -4,6 +4,7 @@ import numpy as np from numpy.typing import NDArray +from fastembed.common.model_description import SparseModelDescription from fastembed.common.types import NumpyArray from fastembed.common.model_management import ModelManagement @@ -30,7 +31,7 @@ def from_dict(cls, data: dict[int, float]) -> "SparseEmbedding": return cls(values=np.array(values), indices=np.array(indices)) -class SparseTextEmbeddingBase(ModelManagement): +class SparseTextEmbeddingBase(ModelManagement[SparseModelDescription]): def __init__( self, model_name: str, diff --git a/fastembed/sparse/sparse_text_embedding.py b/fastembed/sparse/sparse_text_embedding.py index 3447dcae..1e98f1d4 100644 --- a/fastembed/sparse/sparse_text_embedding.py +++ b/fastembed/sparse/sparse_text_embedding.py @@ -1,4 +1,5 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union +from dataclasses import asdict from fastembed.common import OnnxProvider from fastembed.sparse.bm25 import Bm25 @@ -9,6 +10,7 @@ ) from fastembed.sparse.splade_pp import SpladePP import warnings +from fastembed.common.model_description import SparseModelDescription class SparseTextEmbedding(SparseTextEmbeddingBase): @@ -38,9 +40,13 @@ def list_supported_models(cls) -> list[dict[str, Any]]: ] ``` """ - result: list[dict[str, Any]] = [] + return [asdict(model) for model in cls._list_supported_models()] + + @classmethod + def _list_supported_models(cls) -> list[SparseModelDescription]: + result: list[SparseModelDescription] = [] for embedding in cls.EMBEDDINGS_REGISTRY: - result.extend(embedding.list_supported_models()) + result.extend(embedding._list_supported_models()) return result def __init__( @@ -65,8 +71,8 @@ def __init__( model_name = "prithivida/Splade_PP_en_v1" for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: - supported_models = EMBEDDING_MODEL_TYPE.list_supported_models() - if any(model_name.lower() == model["model"].lower() for model in supported_models): + supported_models = EMBEDDING_MODEL_TYPE._list_supported_models() + if any(model_name.lower() == model.model.lower() for model in supported_models): self.model = EMBEDDING_MODEL_TYPE( model_name, cache_dir, diff --git a/fastembed/sparse/splade_pp.py b/fastembed/sparse/splade_pp.py index 82eb001e..ba486426 100644 --- a/fastembed/sparse/splade_pp.py +++ b/fastembed/sparse/splade_pp.py @@ -9,30 +9,27 @@ SparseTextEmbeddingBase, ) from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker - -supported_splade_models = [ - { - "model": "prithivida/Splade_PP_en_v1", - "vocab_size": 30522, - "description": "Independent Implementation of SPLADE++ Model for English.", - "license": "apache-2.0", - "size_in_GB": 0.532, - "sources": { - "hf": "Qdrant/SPLADE_PP_en_v1", - }, - "model_file": "model.onnx", - }, - { - "model": "prithvida/Splade_PP_en_v1", - "vocab_size": 30522, - "description": "Independent Implementation of SPLADE++ Model for English.", - "license": "apache-2.0", - "size_in_GB": 0.532, - "sources": { - "hf": "Qdrant/SPLADE_PP_en_v1", - }, - "model_file": "model.onnx", - }, +from fastembed.common.model_description import SparseModelDescription, ModelSource + +supported_splade_models: list[SparseModelDescription] = [ + SparseModelDescription( + model="prithivida/Splade_PP_en_v1", + vocab_size=30522, + 
description="Independent Implementation of SPLADE++ Model for English.", + license="apache-2.0", + size_in_GB=0.532, + sources=ModelSource(hf="Qdrant/SPLADE_PP_en_v1"), + model_file="model.onnx", + ), + SparseModelDescription( + model="prithvida/Splade_PP_en_v1", + vocab_size=30522, + description="Independent Implementation of SPLADE++ Model for English.", + license="apache-2.0", + size_in_GB=0.532, + sources=ModelSource(hf="Qdrant/SPLADE_PP_en_v1"), + model_file="model.onnx", + ), ] @@ -55,11 +52,11 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Spars yield SparseEmbedding(values=scores, indices=indices) @classmethod - def list_supported_models(cls) -> list[dict[str, Any]]: + def _list_supported_models(cls) -> list[SparseModelDescription]: """Lists the supported models. Returns: - list[dict[str, Any]]: A list of dictionaries containing the model information. + list[SparseModelDescription]: A list of SparseModelDescription objects containing the model information. """ return supported_splade_models @@ -128,7 +125,7 @@ def __init__( def load_onnx_model(self) -> None: self._load_onnx_model( model_dir=self._model_dir, - model_file=self.model_description["model_file"], + model_file=self.model_description.model_file, threads=self.threads, providers=self.providers, cuda=self.cuda, diff --git a/fastembed/text/clip_embedding.py b/fastembed/text/clip_embedding.py index 504ecf57..686d3e7e 100644 --- a/fastembed/text/clip_embedding.py +++ b/fastembed/text/clip_embedding.py @@ -3,19 +3,21 @@ from fastembed.common.types import NumpyArray from fastembed.common.onnx_model import OnnxOutputContext from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker - -supported_clip_models = [ - { - "model": "Qdrant/clip-ViT-B-32-text", - "dim": 512, - "description": "Text embeddings, Multimodal (text&image), English, 77 input tokens truncation, Prefixes for queries/documents: not necessary, 2021 year", - "license": "mit", - "size_in_GB": 0.25, - "sources": { - "hf": "Qdrant/clip-ViT-B-32-text", - }, - "model_file": "model.onnx", - }, +from fastembed.common.model_description import DenseModelDescription, ModelSource + +supported_clip_models: list[DenseModelDescription] = [ + DenseModelDescription( + model="Qdrant/clip-ViT-B-32-text", + dim=512, + description=( + "Text embeddings, Multimodal (text&image), English, 77 input tokens truncation, " + "Prefixes for queries/documents: not necessary, 2021 year" + ), + license="mit", + size_in_GB=0.25, + sources=ModelSource(hf="Qdrant/clip-ViT-B-32-text"), + model_file="model.onnx", + ), ] @@ -25,11 +27,11 @@ def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]: return CLIPEmbeddingWorker @classmethod - def list_supported_models(cls) -> list[dict[str, Any]]: + def _list_supported_models(cls) -> list[DenseModelDescription]: """Lists the supported models. Returns: - list[dict[str, Any]]: A list of dictionaries containing the model information. + list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information. 
""" return supported_clip_models diff --git a/fastembed/text/multitask_embedding.py b/fastembed/text/multitask_embedding.py index 11645bab..a67a337a 100644 --- a/fastembed/text/multitask_embedding.py +++ b/fastembed/text/multitask_embedding.py @@ -6,27 +6,29 @@ from fastembed.common.types import NumpyArray from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding from fastembed.text.onnx_embedding import OnnxTextEmbeddingWorker +from fastembed.common.model_description import DenseModelDescription, ModelSource -supported_multitask_models = [ - { - "model": "jinaai/jina-embeddings-v3", - "dim": 1024, - "tasks": { +supported_multitask_models: list[DenseModelDescription] = [ + DenseModelDescription( + model="jinaai/jina-embeddings-v3", + dim=1024, + tasks={ "retrieval.query": 0, "retrieval.passage": 1, "separation": 2, "classification": 3, "text-matching": 4, }, - "description": "Multi-task unimodal (text) embedding model, multi-lingual (~100), 1024 tokens truncation, and 8192 sequence length. Prefixes for queries/documents: not necessary, 2024 year.", - "license": "cc-by-nc-4.0", - "size_in_GB": 2.29, - "sources": { - "hf": "jinaai/jina-embeddings-v3", - }, - "model_file": "onnx/model.onnx", - "additional_files": ["onnx/model.onnx_data"], - }, + description=( + "Multi-task unimodal (text) embedding model, multi-lingual (~100), " + "1024 tokens truncation, and 8192 sequence length. Prefixes for queries/documents: not necessary, 2024 year." + ), + license="cc-by-nc-4.0", + size_in_GB=2.29, + sources=ModelSource(hf="jinaai/jina-embeddings-v3"), + model_file="onnx/model.onnx", + additional_files=["onnx/model.onnx_data"], + ), ] @@ -51,7 +53,7 @@ def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]: return JinaEmbeddingV3Worker @classmethod - def list_supported_models(cls) -> list[dict[str, Any]]: + def _list_supported_models(cls) -> list[DenseModelDescription]: return supported_multitask_models def _preprocess_onnx_input( diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index c13c2733..bed3ca1e 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -1,172 +1,194 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union import numpy as np - from fastembed.common.types import NumpyArray, OnnxProvider from fastembed.common.onnx_model import OnnxOutputContext from fastembed.common.utils import define_cache_dir, normalize from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker from fastembed.text.text_embedding_base import TextEmbeddingBase +from fastembed.common.model_description import DenseModelDescription, ModelSource -supported_onnx_models = [ - { - "model": "BAAI/bge-base-en", - "dim": 768, - "description": "Text embeddings, Unimodal (text), English, 512 input tokens truncation, Prefixes for queries/documents: necessary, 2023 year.", - "license": "mit", - "size_in_GB": 0.42, - "sources": { - "hf": "Qdrant/fast-bge-base-en", - "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz", - }, - "model_file": "model_optimized.onnx", - }, - { - "model": "BAAI/bge-base-en-v1.5", - "dim": 768, - "description": "Text embeddings, Unimodal (text), English, 512 input tokens truncation, Prefixes for queries/documents: not so necessary, 2023 year.", - "license": "mit", - "size_in_GB": 0.21, - "sources": { - "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz", - "hf": "qdrant/bge-base-en-v1.5-onnx-q", - }, - "model_file": 
"model_optimized.onnx", - }, - { - "model": "BAAI/bge-large-en-v1.5", - "dim": 1024, - "description": "Text embeddings, Unimodal (text), English, 512 input tokens truncation, Prefixes for queries/documents: not so necessary, 2023 year.", - "license": "mit", - "size_in_GB": 1.20, - "sources": { - "hf": "qdrant/bge-large-en-v1.5-onnx", - }, - "model_file": "model.onnx", - }, - { - "model": "BAAI/bge-small-en", - "dim": 384, - "description": "Text embeddings, Unimodal (text), English, 512 input tokens truncation, Prefixes for queries/documents: necessary, 2023 year.", - "license": "mit", - "size_in_GB": 0.13, - "sources": { - "hf": "Qdrant/bge-small-en", - "url": "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz", - }, - "model_file": "model_optimized.onnx", - }, - { - "model": "BAAI/bge-small-en-v1.5", - "dim": 384, - "description": "Text embeddings, Unimodal (text), English, 512 input tokens truncation, Prefixes for queries/documents: not so necessary, 2023 year.", - "license": "mit", - "size_in_GB": 0.067, - "sources": { - "hf": "qdrant/bge-small-en-v1.5-onnx-q", - }, - "model_file": "model_optimized.onnx", - }, - { - "model": "BAAI/bge-small-zh-v1.5", - "dim": 512, - "description": "Text embeddings, Unimodal (text), Chinese, 512 input tokens truncation, Prefixes for queries/documents: not so necessary, 2023 year.", - "license": "mit", - "size_in_GB": 0.09, - "sources": { - "hf": "Qdrant/bge-small-zh-v1.5", - "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz", - }, - "model_file": "model_optimized.onnx", - }, - { - "model": "thenlper/gte-large", - "dim": 1024, - "description": "Text embeddings, Unimodal (text), English, 512 input tokens truncation, Prefixes for queries/documents: not necessary, 2023 year.", - "license": "mit", - "size_in_GB": 1.20, - "sources": { - "hf": "qdrant/gte-large-onnx", - }, - "model_file": "model.onnx", - }, - { - "model": "mixedbread-ai/mxbai-embed-large-v1", - "dim": 1024, - "description": "Text embeddings, Unimodal (text), English, 512 input tokens truncation, Prefixes for queries/documents: necessary, 2024 year.", - "license": "apache-2.0", - "size_in_GB": 0.64, - "sources": { - "hf": "mixedbread-ai/mxbai-embed-large-v1", - }, - "model_file": "onnx/model.onnx", - }, - { - "model": "snowflake/snowflake-arctic-embed-xs", - "dim": 384, - "description": "Text embeddings, Unimodal (text), English, 512 input tokens truncation, Prefixes for queries/documents: necessary, 2024 year.", - "license": "apache-2.0", - "size_in_GB": 0.09, - "sources": { - "hf": "snowflake/snowflake-arctic-embed-xs", - }, - "model_file": "onnx/model.onnx", - }, - { - "model": "snowflake/snowflake-arctic-embed-s", - "dim": 384, - "description": "Text embeddings, Unimodal (text), English, 512 input tokens truncation, Prefixes for queries/documents: necessary, 2024 year.", - "license": "apache-2.0", - "size_in_GB": 0.13, - "sources": { - "hf": "snowflake/snowflake-arctic-embed-s", - }, - "model_file": "onnx/model.onnx", - }, - { - "model": "snowflake/snowflake-arctic-embed-m", - "dim": 768, - "description": "Text embeddings, Unimodal (text), English, 512 input tokens truncation, Prefixes for queries/documents: necessary, 2024 year.", - "license": "apache-2.0", - "size_in_GB": 0.43, - "sources": { - "hf": "Snowflake/snowflake-arctic-embed-m", - }, - "model_file": "onnx/model.onnx", - }, - { - "model": "snowflake/snowflake-arctic-embed-m-long", - "dim": 768, - "description": "Text embeddings, Unimodal (text), English, 2048 input tokens 
truncation, Prefixes for queries/documents: necessary, 2024 year.", - "license": "apache-2.0", - "size_in_GB": 0.54, - "sources": { - "hf": "snowflake/snowflake-arctic-embed-m-long", - }, - "model_file": "onnx/model.onnx", - }, - { - "model": "snowflake/snowflake-arctic-embed-l", - "dim": 1024, - "description": "Text embeddings, Unimodal (text), English, 512 input tokens truncation, Prefixes for queries/documents: necessary, 2024 year.", - "license": "apache-2.0", - "size_in_GB": 1.02, - "sources": { - "hf": "snowflake/snowflake-arctic-embed-l", - }, - "model_file": "onnx/model.onnx", - }, - { - "model": "jinaai/jina-clip-v1", - "dim": 768, - "description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year", - "license": "apache-2.0", - "size_in_GB": 0.55, - "sources": { - "hf": "jinaai/jina-clip-v1", - }, - "model_file": "onnx/text_model.onnx", - }, +supported_onnx_models: list[DenseModelDescription] = [ + DenseModelDescription( + model="BAAI/bge-base-en", + dim=768, + description=( + "Text embeddings, Unimodal (text), English, 512 input tokens truncation, " + "Prefixes for queries/documents: necessary, 2023 year." + ), + license="mit", + size_in_GB=0.42, + sources=ModelSource( + hf="Qdrant/fast-bge-base-en", + url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz", + ), + model_file="model_optimized.onnx", + ), + DenseModelDescription( + model="BAAI/bge-base-en-v1.5", + dim=768, + description=( + "Text embeddings, Unimodal (text), English, 512 input tokens truncation, " + "Prefixes for queries/documents: not so necessary, 2023 year." + ), + license="mit", + size_in_GB=0.21, + sources=ModelSource( + hf="qdrant/bge-base-en-v1.5-onnx-q", + url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz", + ), + model_file="model_optimized.onnx", + ), + DenseModelDescription( + model="BAAI/bge-large-en-v1.5", + dim=1024, + description=( + "Text embeddings, Unimodal (text), English, 512 input tokens truncation, " + "Prefixes for queries/documents: not so necessary, 2023 year." + ), + license="mit", + size_in_GB=1.20, + sources=ModelSource(hf="qdrant/bge-large-en-v1.5-onnx"), + model_file="model.onnx", + ), + DenseModelDescription( + model="BAAI/bge-small-en", + dim=384, + description=( + "Text embeddings, Unimodal (text), English, 512 input tokens truncation, " + "Prefixes for queries/documents: necessary, 2023 year." + ), + license="mit", + size_in_GB=0.13, + sources=ModelSource( + hf="Qdrant/bge-small-en", + url="https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz", + ), + model_file="model_optimized.onnx", + ), + DenseModelDescription( + model="BAAI/bge-small-en-v1.5", + dim=384, + description=( + "Text embeddings, Unimodal (text), English, 512 input tokens truncation, " + "Prefixes for queries/documents: not so necessary, 2023 year." + ), + license="mit", + size_in_GB=0.067, + sources=ModelSource(hf="qdrant/bge-small-en-v1.5-onnx-q"), + model_file="model_optimized.onnx", + ), + DenseModelDescription( + model="BAAI/bge-small-zh-v1.5", + dim=512, + description=( + "Text embeddings, Unimodal (text), Chinese, 512 input tokens truncation, " + "Prefixes for queries/documents: not so necessary, 2023 year." 
+        ),
+        license="mit",
+        size_in_GB=0.09,
+        sources=ModelSource(
+            hf="Qdrant/bge-small-zh-v1.5",
+            url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz",
+        ),
+        model_file="model_optimized.onnx",
+    ),
+    DenseModelDescription(
+        model="thenlper/gte-large",
+        dim=1024,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: not necessary, 2023 year."
+        ),
+        license="mit",
+        size_in_GB=1.20,
+        sources=ModelSource(hf="qdrant/gte-large-onnx"),
+        model_file="model.onnx",
+    ),
+    DenseModelDescription(
+        model="mixedbread-ai/mxbai-embed-large-v1",
+        dim=1024,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.64,
+        sources=ModelSource(hf="mixedbread-ai/mxbai-embed-large-v1"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="snowflake/snowflake-arctic-embed-xs",
+        dim=384,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.09,
+        sources=ModelSource(hf="snowflake/snowflake-arctic-embed-xs"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="snowflake/snowflake-arctic-embed-s",
+        dim=384,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.13,
+        sources=ModelSource(hf="snowflake/snowflake-arctic-embed-s"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="snowflake/snowflake-arctic-embed-m",
+        dim=768,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.43,
+        sources=ModelSource(hf="Snowflake/snowflake-arctic-embed-m"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="snowflake/snowflake-arctic-embed-m-long",
+        dim=768,
+        description=(
+            "Text embeddings, Unimodal (text), English, 2048 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.54,
+        sources=ModelSource(hf="snowflake/snowflake-arctic-embed-m-long"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="snowflake/snowflake-arctic-embed-l",
+        dim=1024,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=1.02,
+        sources=ModelSource(hf="snowflake/snowflake-arctic-embed-l"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="jinaai/jina-clip-v1",
+        dim=768,
+        description=(
+            "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: "
+            "not necessary, 2024 year"
+        ),
+        license="apache-2.0",
+        size_in_GB=0.55,
+        sources=ModelSource(hf="jinaai/jina-clip-v1"),
+        model_file="onnx/text_model.onnx",
+    ),
 ]
@@ -174,12 +196,12 @@
 class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[NumpyArray]):
     """Implementation of the Flag Embedding model."""
 
     @classmethod
-    def list_supported_models(cls) -> list[dict[str, Any]]:
+    def _list_supported_models(cls) -> list[DenseModelDescription]:
         """
         Lists the supported models.
 
        Returns:
-            list[dict[str, Any]]: A list of dictionaries containing the model information.
+            list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
        """
        return supported_onnx_models
@@ -303,7 +325,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Numpy
 
     def load_onnx_model(self) -> None:
         self._load_onnx_model(
             model_dir=self._model_dir,
-            model_file=self.model_description["model_file"],
+            model_file=self.model_description.model_file,
             threads=self.threads,
             providers=self.providers,
             cuda=self.cuda,
diff --git a/fastembed/text/pooled_embedding.py b/fastembed/text/pooled_embedding.py
index 122eed51..33236635 100644
--- a/fastembed/text/pooled_embedding.py
+++ b/fastembed/text/pooled_embedding.py
@@ -5,76 +5,85 @@
 from fastembed.common.types import NumpyArray
 from fastembed.common.onnx_model import OnnxOutputContext
 from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
+from fastembed.common.model_description import DenseModelDescription, ModelSource
 
-supported_pooled_models = [
-    {
-        "model": "nomic-ai/nomic-embed-text-v1.5",
-        "dim": 768,
-        "description": "Text embeddings, Multimodal (text, image), English, 8192 input tokens truncation, Prefixes for queries/documents: necessary, 2024 year.",
-        "license": "apache-2.0",
-        "size_in_GB": 0.52,
-        "sources": {
-            "hf": "nomic-ai/nomic-embed-text-v1.5",
-        },
-        "model_file": "onnx/model.onnx",
-    },
-    {
-        "model": "nomic-ai/nomic-embed-text-v1.5-Q",
-        "dim": 768,
-        "description": "Text embeddings, Multimodal (text, image), English, 8192 input tokens truncation, Prefixes for queries/documents: necessary, 2024 year.",
-        "license": "apache-2.0",
-        "size_in_GB": 0.13,
-        "sources": {
-            "hf": "nomic-ai/nomic-embed-text-v1.5",
-        },
-        "model_file": "onnx/model_quantized.onnx",
-    },
-    {
-        "model": "nomic-ai/nomic-embed-text-v1",
-        "dim": 768,
-        "description": "Text embeddings, Multimodal (text, image), English, 8192 input tokens truncation, Prefixes for queries/documents: necessary, 2024 year.",
-        "license": "apache-2.0",
-        "size_in_GB": 0.52,
-        "sources": {
-            "hf": "nomic-ai/nomic-embed-text-v1",
-        },
-        "model_file": "onnx/model.onnx",
-    },
-    {
-        "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
-        "dim": 384,
-        "description": "Text embeddings, Unimodal (text), Multilingual (~50 languages), 512 input tokens truncation, Prefixes for queries/documents: not necessary, 2019 year.",
-        "license": "apache-2.0",
-        "size_in_GB": 0.22,
-        "sources": {
-            "hf": "qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q",
-        },
-        "model_file": "model_optimized.onnx",
-    },
-    {
-        "model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
-        "dim": 768,
-        "description": "Text embeddings, Unimodal (text), Multilingual (~50 languages), 384 input tokens truncation, Prefixes for queries/documents: not necessary, 2021 year.",
-        "license": "apache-2.0",
-        "size_in_GB": 1.00,
-        "sources": {
-            "hf": "xenova/paraphrase-multilingual-mpnet-base-v2",
-        },
-        "model_file": "onnx/model.onnx",
-    },
-    {
-        "model": "intfloat/multilingual-e5-large",
-        "dim": 1024,
-        "description": "Text embeddings, Unimodal (text), Multilingual (~100 languages), 512 input tokens truncation, Prefixes for queries/documents: necessary, 2024 year.",
-        "license": "mit",
-        "size_in_GB": 2.24,
-        "sources": {
-            "url": "https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz",
-            "hf": "qdrant/multilingual-e5-large-onnx",
-        },
-        "model_file": "model.onnx",
-        "additional_files": ["model.onnx_data"],
-    },
+supported_pooled_models: list[DenseModelDescription] = [
+    DenseModelDescription(
+        model="nomic-ai/nomic-embed-text-v1.5",
+        dim=768,
+        description=(
+            "Text embeddings, Multimodal (text, image), English, 8192 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.52,
+        sources=ModelSource(hf="nomic-ai/nomic-embed-text-v1.5"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="nomic-ai/nomic-embed-text-v1.5-Q",
+        dim=768,
+        description=(
+            "Text embeddings, Multimodal (text, image), English, 8192 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.13,
+        sources=ModelSource(hf="nomic-ai/nomic-embed-text-v1.5"),
+        model_file="onnx/model_quantized.onnx",
+    ),
+    DenseModelDescription(
+        model="nomic-ai/nomic-embed-text-v1",
+        dim=768,
+        description=(
+            "Text embeddings, Multimodal (text, image), English, 8192 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.52,
+        sources=ModelSource(hf="nomic-ai/nomic-embed-text-v1"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
+        dim=384,
+        description=(
+            "Text embeddings, Unimodal (text), Multilingual (~50 languages), 512 input tokens truncation, "
+            "Prefixes for queries/documents: not necessary, 2019 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.22,
+        sources=ModelSource(hf="qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q"),
+        model_file="model_optimized.onnx",
+    ),
+    DenseModelDescription(
+        model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+        dim=768,
+        description=(
+            "Text embeddings, Unimodal (text), Multilingual (~50 languages), 384 input tokens truncation, "
+            "Prefixes for queries/documents: not necessary, 2021 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=1.00,
+        sources=ModelSource(hf="xenova/paraphrase-multilingual-mpnet-base-v2"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="intfloat/multilingual-e5-large",
+        dim=1024,
+        description=(
+            "Text embeddings, Unimodal (text), Multilingual (~100 languages), 512 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="mit",
+        size_in_GB=2.24,
+        sources=ModelSource(
+            hf="qdrant/multilingual-e5-large-onnx",
+            url="https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz",
+        ),
+        model_file="model.onnx",
+        additional_files=["model.onnx_data"],
+    ),
 ]
@@ -96,11 +105,11 @@ def mean_pooling(cls, model_output: NumpyArray, attention_mask: NumpyArray) -> N
         return pooled_embeddings
 
     @classmethod
-    def list_supported_models(cls) -> list[dict[str, Any]]:
+    def _list_supported_models(cls) -> list[DenseModelDescription]:
         """Lists the supported models.
 
         Returns:
-            list[dict[str, Any]]: A list of dictionaries containing the model information.
+            list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
""" return supported_pooled_models diff --git a/fastembed/text/pooled_normalized_embedding.py b/fastembed/text/pooled_normalized_embedding.py index 494eae5b..f0b58b64 100644 --- a/fastembed/text/pooled_normalized_embedding.py +++ b/fastembed/text/pooled_normalized_embedding.py @@ -7,83 +7,108 @@ from fastembed.common.utils import normalize from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker from fastembed.text.pooled_embedding import PooledEmbedding +from fastembed.common.model_description import DenseModelDescription, ModelSource -supported_pooled_normalized_models = [ - { - "model": "sentence-transformers/all-MiniLM-L6-v2", - "dim": 384, - "description": "Text embeddings, Unimodal (text), English, 256 input tokens truncation, Prefixes for queries/documents: not necessary, 2021 year.", - "license": "apache-2.0", - "size_in_GB": 0.09, - "sources": { - "url": "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz", - "hf": "qdrant/all-MiniLM-L6-v2-onnx", - }, - "model_file": "model.onnx", - }, - { - "model": "jinaai/jina-embeddings-v2-base-en", - "dim": 768, - "description": "Text embeddings, Unimodal (text), English, 8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2023 year.", - "license": "apache-2.0", - "size_in_GB": 0.52, - "sources": {"hf": "xenova/jina-embeddings-v2-base-en"}, - "model_file": "onnx/model.onnx", - }, - { - "model": "jinaai/jina-embeddings-v2-small-en", - "dim": 512, - "description": "Text embeddings, Unimodal (text), English, 8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2023 year.", - "license": "apache-2.0", - "size_in_GB": 0.12, - "sources": {"hf": "xenova/jina-embeddings-v2-small-en"}, - "model_file": "onnx/model.onnx", - }, - { - "model": "jinaai/jina-embeddings-v2-base-de", - "dim": 768, - "description": "Text embeddings, Unimodal (text), Multilingual (German, English), 8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year.", - "license": "apache-2.0", - "size_in_GB": 0.32, - "sources": {"hf": "jinaai/jina-embeddings-v2-base-de"}, - "model_file": "onnx/model_fp16.onnx", - }, - { - "model": "jinaai/jina-embeddings-v2-base-code", - "dim": 768, - "description": "Text embeddings, Unimodal (text), Multilingual (English, 30 programming languages), 8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year.", - "license": "apache-2.0", - "size_in_GB": 0.64, - "sources": {"hf": "jinaai/jina-embeddings-v2-base-code"}, - "model_file": "onnx/model.onnx", - }, - { - "model": "jinaai/jina-embeddings-v2-base-zh", - "dim": 768, - "description": "Text embeddings, Unimodal (text), supports mixed Chinese-English input text, 8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year.", - "license": "apache-2.0", - "size_in_GB": 0.64, - "sources": {"hf": "jinaai/jina-embeddings-v2-base-zh"}, - "model_file": "onnx/model.onnx", - }, - { - "model": "jinaai/jina-embeddings-v2-base-es", - "dim": 768, - "description": "Text embeddings, Unimodal (text), supports mixed Spanish-English input text, 8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year.", - "license": "apache-2.0", - "size_in_GB": 0.64, - "sources": {"hf": "jinaai/jina-embeddings-v2-base-es"}, - "model_file": "onnx/model.onnx", - }, - { - "model": "thenlper/gte-base", - "dim": 768, - "description": "General text embeddings, Unimodal (text), supports English only input text, 
512 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year.", - "license": "mit", - "size_in_GB": 0.44, - "sources": {"hf": "thenlper/gte-base"}, - "model_file": "onnx/model.onnx", - }, +supported_pooled_normalized_models: list[DenseModelDescription] = [ + DenseModelDescription( + model="sentence-transformers/all-MiniLM-L6-v2", + dim=384, + description=( + "Text embeddings, Unimodal (text), English, 256 input tokens truncation, " + "Prefixes for queries/documents: not necessary, 2021 year." + ), + license="apache-2.0", + size_in_GB=0.09, + sources=ModelSource( + url="https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz", + hf="qdrant/all-MiniLM-L6-v2-onnx", + ), + model_file="model.onnx", + ), + DenseModelDescription( + model="jinaai/jina-embeddings-v2-base-en", + dim=768, + description=( + "Text embeddings, Unimodal (text), English, 8192 input tokens truncation, " + "Prefixes for queries/documents: not necessary, 2023 year." + ), + license="apache-2.0", + size_in_GB=0.52, + sources=ModelSource(hf="xenova/jina-embeddings-v2-base-en"), + model_file="onnx/model.onnx", + ), + DenseModelDescription( + model="jinaai/jina-embeddings-v2-small-en", + dim=512, + description=( + "Text embeddings, Unimodal (text), English, 8192 input tokens truncation, " + "Prefixes for queries/documents: not necessary, 2023 year." + ), + license="apache-2.0", + size_in_GB=0.12, + sources=ModelSource(hf="xenova/jina-embeddings-v2-small-en"), + model_file="onnx/model.onnx", + ), + DenseModelDescription( + model="jinaai/jina-embeddings-v2-base-de", + dim=768, + description=( + "Text embeddings, Unimodal (text), Multilingual (German, English), 8192 input tokens truncation, " + "Prefixes for queries/documents: not necessary, 2024 year." + ), + license="apache-2.0", + size_in_GB=0.32, + sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-de"), + model_file="onnx/model_fp16.onnx", + ), + DenseModelDescription( + model="jinaai/jina-embeddings-v2-base-code", + dim=768, + description=( + "Text embeddings, Unimodal (text), Multilingual (English, 30 programming languages), " + "8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year." + ), + license="apache-2.0", + size_in_GB=0.64, + sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-code"), + model_file="onnx/model.onnx", + ), + DenseModelDescription( + model="jinaai/jina-embeddings-v2-base-zh", + dim=768, + description=( + "Text embeddings, Unimodal (text), supports mixed Chinese-English input text, " + "8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year." + ), + license="apache-2.0", + size_in_GB=0.64, + sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-zh"), + model_file="onnx/model.onnx", + ), + DenseModelDescription( + model="jinaai/jina-embeddings-v2-base-es", + dim=768, + description=( + "Text embeddings, Unimodal (text), supports mixed Spanish-English input text, " + "8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year." + ), + license="apache-2.0", + size_in_GB=0.64, + sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-es"), + model_file="onnx/model.onnx", + ), + DenseModelDescription( + model="thenlper/gte-base", + dim=768, + description=( + "General text embeddings, Unimodal (text), supports English only input text, " + "512 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year." 
+ ), + license="mit", + size_in_GB=0.44, + sources=ModelSource(hf="thenlper/gte-base"), + model_file="onnx/model.onnx", + ), ] @@ -93,11 +118,11 @@ def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]: return PooledNormalizedEmbeddingWorker @classmethod - def list_supported_models(cls) -> list[dict[str, Any]]: + def _list_supported_models(cls) -> list[DenseModelDescription]: """Lists the supported models. Returns: - list[dict[str, Any]]: A list of dictionaries containing the model information. + list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information. """ return supported_pooled_normalized_models diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index 3273aac1..430d7d3a 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -1,5 +1,6 @@ import warnings from typing import Any, Iterable, Optional, Sequence, Type, Union +from dataclasses import asdict from fastembed.common.types import NumpyArray, OnnxProvider from fastembed.text.clip_embedding import CLIPOnnxEmbedding @@ -8,6 +9,7 @@ from fastembed.text.multitask_embedding import JinaEmbeddingV3 from fastembed.text.onnx_embedding import OnnxTextEmbedding from fastembed.text.text_embedding_base import TextEmbeddingBase +from fastembed.common.model_description import DenseModelDescription class TextEmbedding(TextEmbeddingBase): @@ -21,32 +23,18 @@ class TextEmbedding(TextEmbeddingBase): @classmethod def list_supported_models(cls) -> list[dict[str, Any]]: - """ - Lists the supported models. + """Lists the supported models. Returns: list[dict[str, Any]]: A list of dictionaries containing the model information. - - Example: - ``` - [ - { - "model": "intfloat/multilingual-e5-large", - "dim": 1024, - "description": "Multilingual model, e5-large. 
Recommend using this model for non-English languages", - "license": "mit", - "size_in_GB": 2.24, - "sources": { - "gcp": "https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz", - "hf": "qdrant/multilingual-e5-large-onnx", - } - } - ] - ``` """ - result: list[dict[str, Any]] = [] + return [asdict(model) for model in cls._list_supported_models()] + + @classmethod + def _list_supported_models(cls) -> list[DenseModelDescription]: + result: list[DenseModelDescription] = [] for embedding in cls.EMBEDDINGS_REGISTRY: - result.extend(embedding.list_supported_models()) + result.extend(embedding._list_supported_models()) return result def __init__( @@ -87,8 +75,8 @@ def __init__( ) for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: - supported_models = EMBEDDING_MODEL_TYPE.list_supported_models() - if any(model_name.lower() == model["model"].lower() for model in supported_models): + supported_models = EMBEDDING_MODEL_TYPE._list_supported_models() + if any(model_name.lower() == model.model.lower() for model in supported_models): self.model = EMBEDDING_MODEL_TYPE( model_name=model_name, cache_dir=cache_dir, diff --git a/fastembed/text/text_embedding_base.py b/fastembed/text/text_embedding_base.py index d5dce815..275ad6ef 100644 --- a/fastembed/text/text_embedding_base.py +++ b/fastembed/text/text_embedding_base.py @@ -1,10 +1,11 @@ from typing import Iterable, Optional, Union, Any +from fastembed.common.model_description import DenseModelDescription from fastembed.common.types import NumpyArray from fastembed.common.model_management import ModelManagement -class TextEmbeddingBase(ModelManagement): +class TextEmbeddingBase(ModelManagement[DenseModelDescription]): def __init__( self, model_name: str, diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 00000000..372c2e01 --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,30 @@ +from fastembed import ( + TextEmbedding, + SparseTextEmbedding, + ImageEmbedding, + LateInteractionMultimodalEmbedding, + LateInteractionTextEmbedding, +) + + +def test_text_list_supported_models(): + for model_type in [ + TextEmbedding, + SparseTextEmbedding, + ImageEmbedding, + LateInteractionMultimodalEmbedding, + LateInteractionTextEmbedding, + ]: + supported_models = model_type.list_supported_models() + assert isinstance(supported_models, list) + description = supported_models[0] + assert isinstance(description, dict) + + assert "model" in description and description["model"] + if model_type != SparseTextEmbedding: + assert "dim" in description and description["dim"] + assert "license" in description and description["license"] + assert "size_in_GB" in description and description["size_in_GB"] + assert "model_file" in description and description["model_file"] + assert "sources" in description and description["sources"] + assert "hf" in description["sources"] or "url" in description["sources"] diff --git a/tests/test_image_onnx_embeddings.py b/tests/test_image_onnx_embeddings.py index 1488b276..0d562279 100644 --- a/tests/test_image_onnx_embeddings.py +++ b/tests/test_image_onnx_embeddings.py @@ -30,13 +30,13 @@ def test_embedding() -> None: is_ci = os.getenv("CI") - for model_desc in ImageEmbedding.list_supported_models(): - if not is_ci and model_desc["size_in_GB"] > 1: + for model_desc in ImageEmbedding._list_supported_models(): + if not is_ci and model_desc.size_in_GB > 1: continue - dim = model_desc["dim"] + dim = model_desc.dim - model = ImageEmbedding(model_name=model_desc["model"]) + model = 
ImageEmbedding(model_name=model_desc.model) images = [ TEST_MISC_DIR / "image.jpeg", @@ -48,13 +48,13 @@ def test_embedding() -> None: embeddings = np.stack(embeddings, axis=0) assert embeddings.shape == (len(images), dim) - canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]] + canonical_vector = CANONICAL_VECTOR_VALUES[model_desc.model] assert np.allclose( embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 - ), model_desc["model"] + ), model_desc.model - assert np.allclose(embeddings[1], embeddings[2]), model_desc["model"] + assert np.allclose(embeddings[1], embeddings[2]), model_desc.model if is_ci: delete_model_cache(model.model._model_dir) diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py index b204abb5..b9bc89f1 100644 --- a/tests/test_text_multitask_embeddings.py +++ b/tests/test_text_multitask_embeddings.py @@ -65,12 +65,12 @@ def test_batch_embedding(): docs_to_embed = docs * 10 default_task = Task.RETRIEVAL_PASSAGE - for model_desc in TextEmbedding.list_supported_models(): - if not is_ci and model_desc["size_in_GB"] > 1: + for model_desc in TextEmbedding._list_supported_models(): + if not is_ci and model_desc.size_in_GB > 1: continue - model_name = model_desc["model"] - dim = model_desc["dim"] + model_name = model_desc.model + dim = model_desc.dim if model_name not in CANONICAL_VECTOR_VALUES.keys(): continue @@ -87,7 +87,7 @@ def test_batch_embedding(): canonical_vector = CANONICAL_VECTOR_VALUES[model_name][default_task]["vectors"] assert np.allclose( embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 - ), model_desc["model"] + ), model_desc.model if is_ci: delete_model_cache(model.model._model_dir) @@ -96,12 +96,12 @@ def test_batch_embedding(): def test_single_embedding(): is_ci = os.getenv("CI") - for model_desc in TextEmbedding.list_supported_models(): - if not is_ci and model_desc["size_in_GB"] > 1: + for model_desc in TextEmbedding._list_supported_models(): + if not is_ci and model_desc.size_in_GB > 1: continue - model_name = model_desc["model"] - dim = model_desc["dim"] + model_name = model_desc.model + dim = model_desc.dim if model_name not in CANONICAL_VECTOR_VALUES.keys(): continue @@ -119,7 +119,7 @@ def test_single_embedding(): canonical_vector = task["vectors"] assert np.allclose( embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 - ), model_desc["model"] + ), model_desc.model if is_ci: delete_model_cache(model.model._model_dir) @@ -129,12 +129,12 @@ def test_single_embedding_query(): is_ci = os.getenv("CI") task_id = Task.RETRIEVAL_QUERY - for model_desc in TextEmbedding.list_supported_models(): - if not is_ci and model_desc["size_in_GB"] > 1: + for model_desc in TextEmbedding._list_supported_models(): + if not is_ci and model_desc.size_in_GB > 1: continue - model_name = model_desc["model"] - dim = model_desc["dim"] + model_name = model_desc.model + dim = model_desc.dim if model_name not in CANONICAL_VECTOR_VALUES.keys(): continue @@ -151,7 +151,7 @@ def test_single_embedding_query(): canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"] assert np.allclose( embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 - ), model_desc["model"] + ), model_desc.model if is_ci: delete_model_cache(model.model._model_dir) @@ -161,12 +161,12 @@ def test_single_embedding_passage(): is_ci = os.getenv("CI") task_id = Task.RETRIEVAL_PASSAGE - for model_desc in TextEmbedding.list_supported_models(): - 
if not is_ci and model_desc["size_in_GB"] > 1:
+    for model_desc in TextEmbedding._list_supported_models():
+        if not is_ci and model_desc.size_in_GB > 1:
             continue
 
-        model_name = model_desc["model"]
-        dim = model_desc["dim"]
+        model_name = model_desc.model
+        dim = model_desc.dim
 
         if model_name not in CANONICAL_VECTOR_VALUES.keys():
             continue
@@ -183,7 +183,7 @@
         canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
         assert np.allclose(
             embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
-        ), model_desc["model"]
+        ), model_desc.model
 
     if is_ci:
         delete_model_cache(model.model._model_dir)
@@ -219,11 +219,11 @@ def test_parallel_processing():
 
 def test_task_assignment():
     is_ci = os.getenv("CI")
 
-    for model_desc in TextEmbedding.list_supported_models():
-        if not is_ci and model_desc["size_in_GB"] > 1:
+    for model_desc in TextEmbedding._list_supported_models():
+        if not is_ci and model_desc.size_in_GB > 1:
             continue
 
-        model_name = model_desc["model"]
+        model_name = model_desc.model
 
         if model_name not in CANONICAL_VECTOR_VALUES.keys():
             continue
diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py
index 757e097b..e4bfa0bf 100644
--- a/tests/test_text_onnx_embeddings.py
+++ b/tests/test_text_onnx_embeddings.py
@@ -76,23 +76,23 @@ def test_embedding() -> None:
     is_ci = os.getenv("CI")
     is_mac = platform.system() == "Darwin"
 
-    for model_desc in TextEmbedding.list_supported_models():
+    for model_desc in TextEmbedding._list_supported_models():
         if (
-            (not is_ci and model_desc["size_in_GB"] > 1)
-            or model_desc["model"] in MULTI_TASK_MODELS
-            or (is_mac and model_desc["model"] == "nomic-ai/nomic-embed-text-v1.5-Q")
+            (not is_ci and model_desc.size_in_GB > 1)
+            or model_desc.model in MULTI_TASK_MODELS
+            or (is_mac and model_desc.model == "nomic-ai/nomic-embed-text-v1.5-Q")
         ):
             continue
 
-        dim = model_desc["dim"]
+        dim = model_desc.dim
 
-        model = TextEmbedding(model_name=model_desc["model"])
+        model = TextEmbedding(model_name=model_desc.model)
 
         docs = ["hello world", "flag embedding"]
         embeddings = list(model.embed(docs))
         embeddings = np.stack(embeddings, axis=0)
         assert embeddings.shape == (2, dim)
 
-        canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]]
+        canonical_vector = CANONICAL_VECTOR_VALUES[model_desc.model]
         assert np.allclose(
             embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
-        ), model_desc["model"]
+        ), model_desc.model
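
Reviewer note (not part of the patch): a minimal sketch of the API split this diff introduces, using only names defined above. The internal `_list_supported_models()` returns typed `DenseModelDescription` objects with attribute access, while the public `list_supported_models()` keeps its dict shape by converting through `dataclasses.asdict`, so existing callers are unaffected.

```python
from fastembed import TextEmbedding
from fastembed.common.model_description import DenseModelDescription

# Internal, typed view: dataclass instances with attribute access.
for desc in TextEmbedding._list_supported_models():
    assert isinstance(desc, DenseModelDescription)
    # ModelSource guarantees at least one of hf/url is set (see __post_init__).
    print(desc.model, desc.dim, desc.sources.hf or desc.sources.url)

# Public, backward-compatible view: the same data as plain dicts
# (produced via dataclasses.asdict), so subscript access still works.
for entry in TextEmbedding.list_supported_models():
    assert entry["sources"]["hf"] or entry["sources"]["url"]
```

The leading underscore marks the typed variant as internal; downstream code should keep consuming the public dict-returning method.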