Merged

28 commits
66aad94
Migration of models to dataclasses
I8dNLo Feb 11, 2025
303831f
Model description file
I8dNLo Feb 11, 2025
08efaf7
Test fix
I8dNLo Feb 11, 2025
05b2ce9
kw_only support
I8dNLo Feb 11, 2025
ed1f319
Multitask embeddings test fix
I8dNLo Feb 11, 2025
9ccc9ff
list_supported_models type fix
I8dNLo Feb 11, 2025
5608cef
Dim fix for sparsemodels
I8dNLo Feb 11, 2025
3af5174
Dim fix for sparsemodels
I8dNLo Feb 11, 2025
6ccab21
Dim fix for sparsemodels (x2)
I8dNLo Feb 11, 2025
31da715
Model management type fix
I8dNLo Feb 11, 2025
fbc7077
Interface docstring fixes
I8dNLo Feb 11, 2025
7bd3032
Mypy fixes
I8dNLo Feb 11, 2025
6c299dc
Typing fix again
I8dNLo Feb 11, 2025
dd27ed6
Typing fix again
I8dNLo Feb 11, 2025
8a870f7
Special cast to SparseModelDescription
I8dNLo Feb 11, 2025
0377647
Special cast to SparseModelDescription
I8dNLo Feb 11, 2025
5a0d846
Special cast to SparseModelDescription
I8dNLo Feb 11, 2025
042dc15
typing fix for colpali
I8dNLo Feb 11, 2025
818e875
typing fix for colpali
I8dNLo Feb 11, 2025
6d43348
typing fix for colpali
I8dNLo Feb 11, 2025
5bdf784
typing fix for colpali
I8dNLo Feb 11, 2025
a5fa6ab
Let's try generic typing for ModelManagment
I8dNLo Feb 11, 2025
3da7d62
wip: dataclass idea, small fixes (#475)
joein Feb 14, 2025
8ccf3ea
remove custom model descriptions
joein Feb 15, 2025
d8c46a8
make license, description and size in gb mandatory in model description
joein Feb 15, 2025
997272c
fix: introduce _list_supported_models which returns model description…
joein Feb 15, 2025
954dc70
test: add test for list supported models
joein Feb 15, 2025
85e6241
fix: fix list supported models usage in tests
joein Feb 15, 2025
40 changes: 40 additions & 0 deletions fastembed/common/model_description.py
@@ -0,0 +1,40 @@
from dataclasses import dataclass, field
from typing import Optional, Any


@dataclass(frozen=True)
class ModelSource:
    hf: Optional[str] = None
    url: Optional[str] = None

    def __post_init__(self) -> None:
        if self.hf is None and self.url is None:
            raise ValueError(
                f"At least one source should be set, current sources: hf={self.hf}, url={self.url}"
            )


@dataclass(frozen=True)
class BaseModelDescription:
    model: str
    sources: ModelSource
    model_file: str
    description: str
    license: str
    size_in_GB: float
    additional_files: list[str] = field(default_factory=list)


@dataclass(frozen=True)
class DenseModelDescription(BaseModelDescription):
    dim: Optional[int] = None
    tasks: Optional[dict[str, Any]] = None

    def __post_init__(self) -> None:
        assert self.dim is not None, "dim is required for dense model description"


@dataclass(frozen=True)
class SparseModelDescription(BaseModelDescription):
    requires_idf: Optional[bool] = None
    vocab_size: Optional[int] = None
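For context, a minimal usage sketch of the new description dataclasses. The model name and field values below are illustrative, not entries from the real registry:

```python
from fastembed.common.model_description import DenseModelDescription, ModelSource

# Illustrative values only, not an entry from fastembed's actual registry.
description = DenseModelDescription(
    model="example-org/example-model",
    sources=ModelSource(hf="example-org/example-model"),
    model_file="model.onnx",
    description="Example dense embedding model",
    license="apache-2.0",
    size_in_GB=0.5,
    dim=384,
)

# ModelSource requires at least one of `hf` or `url`: ModelSource() raises ValueError
# in __post_init__, and DenseModelDescription asserts that `dim` is set.
# frozen=True keeps every description immutable after construction.
print(description.model, description.dim)
```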
51 changes: 31 additions & 20 deletions fastembed/common/model_management.py
@@ -4,7 +4,7 @@
import shutil
import tarfile
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any, Optional, Union, TypeVar, Generic

import requests
from huggingface_hub import snapshot_download, model_info, list_repo_tree
@@ -16,22 +16,29 @@
)
from loguru import logger
from tqdm import tqdm
from fastembed.common.model_description import BaseModelDescription

T = TypeVar("T", bound=BaseModelDescription)

class ModelManagement:

class ModelManagement(Generic[T]):
    METADATA_FILE = "files_metadata.json"

    @classmethod
    def list_supported_models(cls) -> list[dict[str, Any]]:
        """Lists the supported models.

        Returns:
            list[dict[str, Any]]: A list of dictionaries containing the model information.
            list[T]: A list of dictionaries containing the model information.
        """
        raise NotImplementedError()

    @classmethod
    def _get_model_description(cls, model_name: str) -> dict[str, Any]:
    def _list_supported_models(cls) -> list[T]:
        raise NotImplementedError()

    @classmethod
    def _get_model_description(cls, model_name: str) -> T:
        """
        Gets the model description from the model_name.

@@ -42,10 +49,10 @@ def _get_model_description(cls, model_name: str) -> dict[str, Any]:
            ValueError: If the model_name is not supported.

        Returns:
            dict[str, Any]: The model description.
            T: The model description.
        """
        for model in cls.list_supported_models():
            if model_name.lower() == model["model"].lower():
        for model in cls._list_supported_models():
            if model_name.lower() == model.model.lower():
                return model

        raise ValueError(f"Model {model_name} is not supported in {cls.__name__}.")
@@ -160,7 +167,9 @@ def _collect_file_metadata(
            }
            return meta

        def _save_file_metadata(model_dir: Path, meta: dict[str, dict[str, Union[int, str]]]) -> None:
        def _save_file_metadata(
            model_dir: Path, meta: dict[str, dict[str, Union[int, str]]]
        ) -> None:
            try:
                if not model_dir.exists():
                    model_dir.mkdir(parents=True, exist_ok=True)
@@ -292,7 +301,11 @@ def decompress_to_cache(cls, targz_path: str, cache_dir: str) -> str:

    @classmethod
    def retrieve_model_gcs(
        cls, model_name: str, source_url: str, cache_dir: str, local_files_only: bool = False
        cls,
        model_name: str,
        source_url: str,
        cache_dir: str,
        local_files_only: bool = False,
    ) -> Path:
        fast_model_name = f"fast-{model_name.split('/')[-1]}"
        cache_tmp_dir = Path(cache_dir) / "tmp"
@@ -336,14 +349,12 @@ def retrieve_model_gcs(
        return model_dir

    @classmethod
    def download_model(
        cls, model: dict[str, Any], cache_dir: str, retries: int = 3, **kwargs: Any
    ) -> Path:
    def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: Any) -> Path:
        """
        Downloads a model from HuggingFace Hub or Google Cloud Storage.

        Args:
            model (dict[str, Any]): The model description.
            model (T): The model description.
                Example:
                ```
                {
@@ -368,16 +379,16 @@ def download_model(
        if specific_model_path:
            return Path(specific_model_path)
        retries = 1 if local_files_only else retries
        hf_source = model.get("sources", {}).get("hf")
        url_source = model.get("sources", {}).get("url")
        hf_source = model.sources.hf
        url_source = model.sources.url

        sleep = 3.0
        while retries > 0:
            retries -= 1

            if hf_source:
                extra_patterns = [model["model_file"]]
                extra_patterns.extend(model.get("additional_files", []))
                extra_patterns = [model.model_file]
                extra_patterns.extend(model.additional_files)

                try:
                    return Path(
@@ -399,8 +410,8 @@ def download_model(
            if url_source or local_files_only:
                try:
                    return cls.retrieve_model_gcs(
                        model["model"],
                        url_source,
                        model.model,
                        str(url_source),
                        str(cache_dir),
                        local_files_only=local_files_only,
                    )
@@ -417,4 +428,4 @@ def download_model(
            time.sleep(sleep)
            sleep *= 3

        raise ValueError(f"Could not load model {model['model']} from any source.")
        raise ValueError(f"Could not load model {model.model} from any source.")
14 changes: 10 additions & 4 deletions fastembed/image/image_embedding.py
@@ -1,9 +1,11 @@
from typing import Any, Iterable, Optional, Sequence, Type, Union
from dataclasses import asdict

from fastembed.common.types import NumpyArray
from fastembed.common import ImageInput, OnnxProvider
from fastembed.image.image_embedding_base import ImageEmbeddingBase
from fastembed.image.onnx_embedding import OnnxImageEmbedding
from fastembed.common.model_description import DenseModelDescription


class ImageEmbedding(ImageEmbeddingBase):
@@ -34,9 +36,13 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
            ]
        ```
        """
        result: list[dict[str, Any]] = []
        return [asdict(model) for model in cls._list_supported_models()]

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        result: list[DenseModelDescription] = []
        for embedding in cls.EMBEDDINGS_REGISTRY:
            result.extend(embedding.list_supported_models())
            result.extend(embedding._list_supported_models())
        return result

    def __init__(
@@ -52,8 +58,8 @@ def __init__(
    ):
        super().__init__(model_name, cache_dir, threads, **kwargs)
        for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
            supported_models = EMBEDDING_MODEL_TYPE.list_supported_models()
            if any(model_name.lower() == model["model"].lower() for model in supported_models):
            supported_models = EMBEDDING_MODEL_TYPE._list_supported_models()
            if any(model_name.lower() == model.model.lower() for model in supported_models):
                self.model = EMBEDDING_MODEL_TYPE(
                    model_name,
                    cache_dir,
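From the caller's side, the split looks roughly like this. A sketch assuming this branch of fastembed is installed; printed values depend on the registry:

```python
from fastembed.image.image_embedding import ImageEmbedding

# Public API is unchanged for callers: plain dicts built via dataclasses.asdict.
for model in ImageEmbedding.list_supported_models():
    print(model["model"], model["dim"], model["size_in_GB"])

# Internal, typed view used by the library itself: frozen dataclasses.
for description in ImageEmbedding._list_supported_models():
    print(description.model, description.dim, description.size_in_GB)
```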
3 changes: 2 additions & 1 deletion fastembed/image/image_embedding_base.py
@@ -1,11 +1,12 @@
from typing import Iterable, Optional, Any, Union

from fastembed.common.model_description import DenseModelDescription
from fastembed.common.types import NumpyArray
from fastembed.common.model_management import ModelManagement
from fastembed.common.types import ImageInput


class ImageEmbeddingBase(ModelManagement):
class ImageEmbeddingBase(ModelManagement[DenseModelDescription]):
    def __init__(
        self,
        model_name: str,
110 changes: 51 additions & 59 deletions fastembed/image/onnx_embedding.py
@@ -9,62 +9,54 @@
from fastembed.image.image_embedding_base import ImageEmbeddingBase
from fastembed.image.onnx_image_model import ImageEmbeddingWorker, OnnxImageModel

supported_onnx_models = [
    {
        "model": "Qdrant/clip-ViT-B-32-vision",
        "dim": 512,
        "description": "Image embeddings, Multimodal (text&image), 2021 year",
        "license": "mit",
        "size_in_GB": 0.34,
        "sources": {
            "hf": "Qdrant/clip-ViT-B-32-vision",
        },
        "model_file": "model.onnx",
    },
    {
        "model": "Qdrant/resnet50-onnx",
        "dim": 2048,
        "description": "Image embeddings, Unimodal (image), 2016 year",
        "license": "apache-2.0",
        "size_in_GB": 0.1,
        "sources": {
            "hf": "Qdrant/resnet50-onnx",
        },
        "model_file": "model.onnx",
    },
    {
        "model": "Qdrant/Unicom-ViT-B-16",
        "dim": 768,
        "description": "Image embeddings (more detailed than Unicom-ViT-B-32), Multimodal (text&image), 2023 year",
        "license": "apache-2.0",
        "size_in_GB": 0.82,
        "sources": {
            "hf": "Qdrant/Unicom-ViT-B-16",
        },
        "model_file": "model.onnx",
    },
    {
        "model": "Qdrant/Unicom-ViT-B-32",
        "dim": 512,
        "description": "Image embeddings, Multimodal (text&image), 2023 year",
        "license": "apache-2.0",
        "size_in_GB": 0.48,
        "sources": {
            "hf": "Qdrant/Unicom-ViT-B-32",
        },
        "model_file": "model.onnx",
    },
    {
        "model": "jinaai/jina-clip-v1",
        "dim": 768,
        "description": "Image embeddings, Multimodal (text&image), 2024 year",
        "license": "apache-2.0",
        "size_in_GB": 0.34,
        "sources": {
            "hf": "jinaai/jina-clip-v1",
        },
        "model_file": "onnx/vision_model.onnx",
    },
from fastembed.common.model_description import DenseModelDescription, ModelSource

supported_onnx_models: list[DenseModelDescription] = [
    DenseModelDescription(
        model="Qdrant/clip-ViT-B-32-vision",
        dim=512,
        description="Image embeddings, Multimodal (text&image), 2021 year",
        license="mit",
        size_in_GB=0.34,
        sources=ModelSource(hf="Qdrant/clip-ViT-B-32-vision"),
        model_file="model.onnx",
    ),
    DenseModelDescription(
        model="Qdrant/resnet50-onnx",
        dim=2048,
        description="Image embeddings, Unimodal (image), 2016 year",
        license="apache-2.0",
        size_in_GB=0.1,
        sources=ModelSource(hf="Qdrant/resnet50-onnx"),
        model_file="model.onnx",
    ),
    DenseModelDescription(
        model="Qdrant/Unicom-ViT-B-16",
        dim=768,
        description="Image embeddings (more detailed than Unicom-ViT-B-32), Multimodal (text&image), 2023 year",
        license="apache-2.0",
        size_in_GB=0.82,
        sources=ModelSource(hf="Qdrant/Unicom-ViT-B-16"),
        model_file="model.onnx",
    ),
    DenseModelDescription(
        model="Qdrant/Unicom-ViT-B-32",
        dim=512,
        description="Image embeddings, Multimodal (text&image), 2023 year",
        license="apache-2.0",
        size_in_GB=0.48,
        sources=ModelSource(hf="Qdrant/Unicom-ViT-B-32"),
        model_file="model.onnx",
    ),
    DenseModelDescription(
        model="jinaai/jina-clip-v1",
        dim=768,
        description="Image embeddings, Multimodal (text&image), 2024 year",
        license="apache-2.0",
        size_in_GB=0.34,
        sources=ModelSource(hf="jinaai/jina-clip-v1"),
        model_file="onnx/vision_model.onnx",
    ),
]


@@ -137,20 +129,20 @@ def load_onnx_model(self) -> None:
"""
self._load_onnx_model(
model_dir=self._model_dir,
model_file=self.model_description["model_file"],
model_file=self.model_description.model_file,
threads=self.threads,
providers=self.providers,
cuda=self.cuda,
device_id=self.device_id,
)

@classmethod
def list_supported_models(cls) -> list[dict[str, Any]]:
def _list_supported_models(cls) -> list[DenseModelDescription]:
"""
Lists the supported models.

Returns:
list[dict[str, Any]]: A list of dictionaries containing the model information.
list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
"""
return supported_onnx_models

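The same attribute-access migration applies to the registry entries themselves. A small sketch of reading an entry from the module above, assuming this branch is installed:

```python
import dataclasses

from fastembed.image.onnx_embedding import supported_onnx_models

# Entries are typed, frozen dataclasses, so dict lookups become attribute access:
# model["model_file"] -> model.model_file, model["sources"]["hf"] -> model.sources.hf
clip = supported_onnx_models[0]
print(clip.model, clip.dim, clip.sources.hf, clip.model_file)

# frozen=True makes descriptions read-only; mutation raises FrozenInstanceError.
try:
    clip.dim = 1024  # type: ignore[misc]
except dataclasses.FrozenInstanceError:
    print("model descriptions are immutable")
```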