@@ -1,12 +1,14 @@
 import os
 import time
+import json
 import shutil
 import tarfile
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 import requests
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, model_info, list_repo_tree
+from huggingface_hub.hf_api import RepoFile
 from huggingface_hub.utils import (
     RepositoryNotFoundError,
     disable_progress_bars,
@@ -17,6 +19,8 @@
 
 
 class ModelManagement:
+    METADATA_FILE = "files_metadata.json"
+
     @classmethod
     def list_supported_models(cls) -> list[dict[str, Any]]:
         """Lists the supported models.
@@ -98,7 +102,7 @@ def download_files_from_huggingface(
         cls,
         hf_source_repo: str,
         cache_dir: str,
-        extra_patterns: Optional[list[str]] = None,
+        extra_patterns: list[str],
         local_files_only: bool = False,
         **kwargs,
     ) -> str:
@@ -107,36 +111,148 @@ def download_files_from_huggingface(
         Args:
             hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
             cache_dir (Optional[str]): The path to the cache directory.
-            extra_patterns (Optional[list[str]]): extra patterns to allow in the snapshot download, typically
+            extra_patterns (list[str]): extra patterns to allow in the snapshot download, typically
                 includes the required model files.
             local_files_only (bool, optional): Whether to only use local files. Defaults to False.
         Returns:
             Path: The path to the model directory.
         """
+
+        def _verify_files_from_metadata(
+            model_dir: Path, stored_metadata: dict[str, Any], repo_files: list[RepoFile]
+        ) -> bool:
+            try:
+                for rel_path, meta in stored_metadata.items():
+                    file_path = model_dir / rel_path
+
+                    if not file_path.exists():
+                        return False
+
+                    if repo_files:  # online verification
+                        file_info = next((f for f in repo_files if f.path == file_path.name), None)
+                        if (
+                            not file_info
+                            or file_info.size != meta["size"]
+                            or file_info.blob_id != meta["blob_id"]
+                        ):
+                            return False
+
+                    else:  # offline verification
+                        if file_path.stat().st_size != meta["size"]:
+                            return False
+                return True
+            except (OSError, KeyError) as e:
+                logger.error(f"Error verifying files: {str(e)}")
+                return False
+
+        def _collect_file_metadata(
+            model_dir: Path, repo_files: list[RepoFile]
+        ) -> dict[str, dict[str, int]]:
+            meta = {}
+            file_info_map = {f.path: f for f in repo_files}
+            for file_path in model_dir.rglob("*"):
+                if file_path.is_file() and file_path.name != cls.METADATA_FILE:
+                    repo_file = file_info_map.get(file_path.name)
+                    if repo_file:
+                        meta[str(file_path.relative_to(model_dir))] = {
+                            "size": repo_file.size,
+                            "blob_id": repo_file.blob_id,
+                        }
+            return meta
+
+        def _save_file_metadata(model_dir: Path, meta: dict[str, dict[str, int]]) -> None:
+            try:
+                if not model_dir.exists():
+                    model_dir.mkdir(parents=True, exist_ok=True)
+                (model_dir / cls.METADATA_FILE).write_text(json.dumps(meta))
+            except (OSError, ValueError) as e:
+                logger.warning(f"Error saving metadata: {str(e)}")
+
         allow_patterns = [
             "config.json",
             "tokenizer.json",
             "tokenizer_config.json",
             "special_tokens_map.json",
             "preprocessor_config.json",
         ]
-        if extra_patterns is not None:
-            allow_patterns.extend(extra_patterns)
+
+        allow_patterns.extend(extra_patterns)
 
         snapshot_dir = Path(cache_dir) / f"models--{hf_source_repo.replace('/', '--')}"
-        is_cached = snapshot_dir.exists()
+        metadata_file = snapshot_dir / cls.METADATA_FILE
+
+        if local_files_only:
+            disable_progress_bars()
+            if metadata_file.exists():
+                metadata = json.loads(metadata_file.read_text())
+                verified = _verify_files_from_metadata(snapshot_dir, metadata, repo_files=[])
+                if not verified:
+                    logger.warning(
+                        "Local file sizes do not match the metadata."
+                    )  # do not raise, still make an attempt to load the model
+            else:
+                logger.warning(
+                    "Metadata file not found. Proceeding without checking local files."
+                )  # if users have downloaded models from hf manually, or they're updating from previous versions of
+                # fastembed
+            result = snapshot_download(
+                repo_id=hf_source_repo,
+                allow_patterns=allow_patterns,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                **kwargs,
+            )
+            return result
+
+        repo_revision = model_info(hf_source_repo).sha
+        repo_tree = list(list_repo_tree(hf_source_repo, revision=repo_revision, repo_type="model"))
+
+        allowed_extensions = {".json", ".onnx", ".txt"}
+        repo_files = (
+            [
+                f
+                for f in repo_tree
+                if isinstance(f, RepoFile) and Path(f.path).suffix in allowed_extensions
+            ]
+            if repo_tree
+            else []
+        )
+
+        verified_metadata = False
+
+        if snapshot_dir.exists() and metadata_file.exists():
+            metadata = json.loads(metadata_file.read_text())
+            verified_metadata = _verify_files_from_metadata(snapshot_dir, metadata, repo_files)
 
-        if is_cached:
+        if verified_metadata:
             disable_progress_bars()
 
-        return snapshot_download(
+        result = snapshot_download(
             repo_id=hf_source_repo,
             allow_patterns=allow_patterns,
             cache_dir=cache_dir,
             local_files_only=local_files_only,
             **kwargs,
         )
 
+        if (
+            not verified_metadata
+        ):  # metadata is not up-to-date, update it and check whether the files have been
+            # downloaded correctly
+            metadata = _collect_file_metadata(snapshot_dir, repo_files)
+
+            download_successful = _verify_files_from_metadata(
+                snapshot_dir, metadata, repo_files=[]
+            )  # offline verification
+            if not download_successful:
+                raise ValueError(
+                    "Files have been corrupted during downloading process. "
+                    "Please check your internet connection and try again."
+                )
+            _save_file_metadata(snapshot_dir, metadata)
+
+        return result
+
     @classmethod
     def decompress_to_cache(cls, targz_path: str, cache_dir: str):
         """
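For reference, a minimal usage sketch of the flow this change introduces. The import path (fastembed.common.model_management) and the "model.onnx" pattern are assumptions for illustration, not part of the diff; the repo id comes from the docstring example. The first call downloads the snapshot, verifies the files offline, and writes files_metadata.json into the snapshot directory; a later call with local_files_only=True re-checks the cached file sizes against that metadata and only warns if they do not match.

# Sketch only; import path and extra pattern are illustrative assumptions.
from fastembed.common.model_management import ModelManagement

cache_dir = "models_cache"

# Online path: lists the repo tree, downloads the allowed files, verifies
# them, then saves files_metadata.json next to the snapshot.
model_dir = ModelManagement.download_files_from_huggingface(
    hf_source_repo="qdrant/all-MiniLM-L6-v2-onnx",
    cache_dir=cache_dir,
    extra_patterns=["model.onnx"],  # now a required argument
)

# Offline path: verifies cached file sizes against files_metadata.json and
# proceeds with a warning (no exception) if they do not match.
model_dir = ModelManagement.download_files_from_huggingface(
    hf_source_repo="qdrant/all-MiniLM-L6-v2-onnx",
    cache_dir=cache_dir,
    extra_patterns=["model.onnx"],
    local_files_only=True,
)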