Skip to content
Closed
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions fastembed/common/model_management.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import time
import json
import shutil
import tarfile
from pathlib import Path
from typing import Any, Optional
from typing import Any

import requests
from huggingface_hub import snapshot_download
Expand Down Expand Up @@ -98,7 +99,7 @@ def download_files_from_huggingface(
cls,
hf_source_repo: str,
cache_dir: str,
extra_patterns: Optional[list[str]] = None,
extra_patterns: list[str],
local_files_only: bool = False,
**kwargs,
) -> str:
Expand All @@ -107,7 +108,7 @@ def download_files_from_huggingface(
Args:
hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
cache_dir (Optional[str]): The path to the cache directory.
extra_patterns (Optional[list[str]]): extra patterns to allow in the snapshot download, typically
extra_patterns (list[str]): extra patterns to allow in the snapshot download, typically
includes the required model files.
local_files_only (bool, optional): Whether to only use local files. Defaults to False.
Returns:
Expand All @@ -120,23 +121,52 @@ def download_files_from_huggingface(
"special_tokens_map.json",
"preprocessor_config.json",
]
if extra_patterns is not None:
allow_patterns.extend(extra_patterns)

model_file = next((file for file in extra_patterns if file.endswith(".onnx")), "")
allow_patterns.extend(extra_patterns)

snapshot_dir = Path(cache_dir) / f"models--{hf_source_repo.replace('/', '--')}"
is_cached = snapshot_dir.exists()
metadata_file = snapshot_dir / "files_metadata.json"

def _verify_files_from_metadata(model_dir: Path, stored_metadata: dict[str, Any]) -> bool:
for rel_path, meta in stored_metadata.items():
file_path = model_dir / rel_path
if not file_path.exists() or file_path.stat().st_size != meta["size"]:
return False
return True

def _save_file_metadata(model_dir: Path) -> None:
metadata = {}
for file_path in model_dir.rglob("*"):
if file_path.is_file() and file_path.name != "files_metadata.json":
rel_path = str(file_path.relative_to(model_dir))
metadata[rel_path] = {
"size": file_path.stat().st_size,
}

metadata_file = model_dir / "files_metadata.json"
metadata_file.write_text(json.dumps(metadata))

if is_cached:
disable_progress_bars()
if snapshot_dir.exists() and metadata_file.exists():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if version on hf is different from the one we have locally, then we will hide the progress bar and then we will silently download the updated files
I think we could make a corresponding call to HfApi and check revision and commit hash

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain it further? As far as Andrey said, he doesn't want to make a call to the HF API since it requires network access. Do you mean that we could call it only when downloading for the first time and add the revision to the metadata? And only call it when there's network?

Copy link
Contributor Author

@hh-space-invader hh-space-invader Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a _get_file_hash function to compute hash for each file and then later on checked with _verify_files_from_metadata if the version changed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's safe to make this call if local_files_only != True

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iirc in snapshot_download they just pull the whole repo info and compare the revision's commit hash

stored_metadata = json.loads(metadata_file.read_text())
if _verify_files_from_metadata(snapshot_dir, stored_metadata):
disable_progress_bars()

return snapshot_download(
result = snapshot_download(
repo_id=hf_source_repo,
allow_patterns=allow_patterns,
cache_dir=cache_dir,
local_files_only=local_files_only,
**kwargs,
)

if not os.path.exists(os.path.join(result, model_file)):
raise FileNotFoundError("Couldn't download model from huggingface")
Comment on lines +189 to +190
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if any of the other required files is missing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I know, the blobs are downloaded first, then they are extracted into the .onnx and the other required files. So if the .onnx exists, the blobs were extracted correctly. You can check that by deleting the snapshot folder and trying to run the model: the snapshot folder will be extracted again and the .onnx (and the other files) will exist.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this check is redundant, let's rely on hf here
also we have models without onnx files like Qdrant/bm25 and then it just checks existence of a directory

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and we can also remove model_file variable then


if not local_files_only:
_save_file_metadata(snapshot_dir)
return result

@classmethod
def decompress_to_cache(cls, targz_path: str, cache_dir: str):
"""
Expand Down
Loading