@@ -3,6 +3,7 @@
 import json
 import shutil
 import tarfile
+from copy import deepcopy
 from pathlib import Path
 from typing import Any, Optional, Union, TypeVar, Generic
 
@@ -224,11 +225,6 @@ def _save_file_metadata(
             logger.warning(
                 "Local file sizes do not match the metadata."
             )  # do not raise, still make an attempt to load the model
-        else:
-            logger.warning(
-                "Metadata file not found. Proceeding without checking local files."
-            )  # if users have downloaded models from hf manually, or they're updating from previous versions of
-            # fastembed
         result = snapshot_download(
             repo_id=hf_source_repo,
             allow_patterns=allow_patterns,
@@ -408,14 +404,32 @@ def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: An
         hf_source = model.sources.hf
         url_source = model.sources.url
 
+        extra_patterns = [model.model_file]
+        extra_patterns.extend(model.additional_files)
+
+        if hf_source:
+            try:
+                cache_kwargs = deepcopy(kwargs)
+                cache_kwargs["local_files_only"] = True
+                return Path(
+                    cls.download_files_from_huggingface(
+                        hf_source,
+                        cache_dir=cache_dir,
+                        extra_patterns=extra_patterns,
+                        **cache_kwargs,
+                    )
+                )
+            except Exception:
+                pass
+            finally:
+                enable_progress_bars()
+
         sleep = 3.0
         while retries > 0:
             retries -= 1
 
-            if hf_source:
-                extra_patterns = [model.model_file]
-                extra_patterns.extend(model.additional_files)
-
+            if hf_source and not local_files_only:
+                # we have already tried loading with `local_files_only=True` via hf and we failed
                 try:
                     return Path(
                         cls.download_files_from_huggingface(
@@ -448,11 +462,12 @@ def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: An
 
             if local_files_only:
                 logger.error("Could not find model in cache_dir")
+                break
             else:
                 logger.error(
                     f"Could not download model from either source, sleeping for {sleep} seconds, {retries} retries left."
                 )
-            time.sleep(sleep)
-            sleep *= 3
+                time.sleep(sleep)
+                sleep *= 3
 
         raise ValueError(f"Could not load model {model.model} from any source.")
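To make the new control flow easier to follow outside the diff context: the change first attempts a purely local, cache-only load (`local_files_only=True`), silently swallows a failure there, and only then enters the existing retry loop with exponential backoff, which now breaks out early when the caller forbids network access. Below is a minimal, self-contained sketch of that pattern; `load_model_path` and `download_fn` are hypothetical stand-ins for illustration, not fastembed's API.

```python
import time
from pathlib import Path
from typing import Callable


def load_model_path(
    download_fn: Callable[[bool], Path],  # hypothetical: download_fn(local_files_only) -> Path
    retries: int = 3,
    local_files_only: bool = False,
) -> Path:
    # 1) Cheap attempt: resolve the model purely from the local cache, no network.
    #    A failure here just means "cache miss"; fall through to the loop below.
    try:
        return download_fn(True)
    except Exception:
        pass

    # 2) Network attempts with exponential backoff.
    sleep = 3.0
    while retries > 0:
        retries -= 1

        if not local_files_only:
            # The cache-only attempt above already failed, so hit the network this time.
            try:
                return download_fn(False)
            except Exception:
                pass

        if local_files_only:
            # Caller forbids network access and the cache attempt failed: give up early.
            break

        time.sleep(sleep)
        sleep *= 3

    raise ValueError("Could not load model from any source.")
```

A warm cache therefore never enters the retry/backoff machinery or touches the network at all, which is the main point of hoisting the cache-only HF attempt out of the loop.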