Skip to content

Commit 44e3329

Browse files
authored
new: try loading models from cache before making any network calls (#577)
1 parent 533b54c commit 44e3329

File tree

1 file changed

+26
-11
lines changed

1 file changed

+26
-11
lines changed

fastembed/common/model_management.py

Lines changed: 26 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import json
44
import shutil
55
import tarfile
6+
from copy import deepcopy
67
from pathlib import Path
78
from typing import Any, Optional, Union, TypeVar, Generic
89

@@ -224,11 +225,6 @@ def _save_file_metadata(
224225
logger.warning(
225226
"Local file sizes do not match the metadata."
226227
) # do not raise, still make an attempt to load the model
227-
else:
228-
logger.warning(
229-
"Metadata file not found. Proceeding without checking local files."
230-
) # if users have downloaded models from hf manually, or they're updating from previous versions of
231-
# fastembed
232228
result = snapshot_download(
233229
repo_id=hf_source_repo,
234230
allow_patterns=allow_patterns,
@@ -408,14 +404,32 @@ def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: An
408404
hf_source = model.sources.hf
409405
url_source = model.sources.url
410406

407+
extra_patterns = [model.model_file]
408+
extra_patterns.extend(model.additional_files)
409+
410+
if hf_source:
411+
try:
412+
cache_kwargs = deepcopy(kwargs)
413+
cache_kwargs["local_files_only"] = True
414+
return Path(
415+
cls.download_files_from_huggingface(
416+
hf_source,
417+
cache_dir=cache_dir,
418+
extra_patterns=extra_patterns,
419+
**cache_kwargs,
420+
)
421+
)
422+
except Exception:
423+
pass
424+
finally:
425+
enable_progress_bars()
426+
411427
sleep = 3.0
412428
while retries > 0:
413429
retries -= 1
414430

415-
if hf_source:
416-
extra_patterns = [model.model_file]
417-
extra_patterns.extend(model.additional_files)
418-
431+
if hf_source and not local_files_only:
432+
# we have already tried loading with `local_files_only=True` via hf and we failed
419433
try:
420434
return Path(
421435
cls.download_files_from_huggingface(
@@ -448,11 +462,12 @@ def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: An
448462

449463
if local_files_only:
450464
logger.error("Could not find model in cache_dir")
465+
break
451466
else:
452467
logger.error(
453468
f"Could not download model from either source, sleeping for {sleep} seconds, {retries} retries left."
454469
)
455-
time.sleep(sleep)
456-
sleep *= 3
470+
time.sleep(sleep)
471+
sleep *= 3
457472

458473
raise ValueError(f"Could not load model {model.model} from any source.")

0 commit comments

Comments
 (0)