| 
22 | 22 | def _download_hf_snapshot(  | 
23 | 23 |     model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]  | 
24 | 24 | ):  | 
25 |  | -    from huggingface_hub import snapshot_download  | 
 | 25 | +    from huggingface_hub import model_info, snapshot_download  | 
26 | 26 |     from requests.exceptions import HTTPError  | 
27 | 27 | 
 
  | 
28 | 28 |     # Download and store the HF model artifacts.  | 
29 | 29 |     print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr)  | 
30 | 30 |     try:  | 
 | 31 | +        # Fetch the info about the model's repo  | 
 | 32 | +        model_info = model_info(model_config.distribution_path, token=hf_token)  | 
 | 33 | +        model_fnames = [f.rfilename for f in model_info.siblings]  | 
 | 34 | + | 
 | 35 | +        # Check the model config for preference between safetensors and pth  | 
 | 36 | +        has_pth = any(f.endswith(".pth") for f in model_fnames)  | 
 | 37 | +        has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)  | 
 | 38 | + | 
 | 39 | +        # If told to prefer safetensors, ignore pth files  | 
 | 40 | +        if model_config.prefer_safetensors:  | 
 | 41 | +            if not has_safetensors:  | 
 | 42 | +                print(  | 
 | 43 | +                    f"Model {model_config.name} does not have safetensors files, but prefer_safetensors is set to True. Using pth files instead.",  | 
 | 44 | +                    file=sys.stderr,  | 
 | 45 | +                )  | 
 | 46 | +                exit(1)  | 
 | 47 | +            ignore_patterns = "*.pth"  | 
 | 48 | + | 
 | 49 | +        # If the model has both, prefer pth files over safetensors  | 
 | 50 | +        elif has_pth and has_safetensors:  | 
 | 51 | +            ignore_patterns = "*safetensors*"  | 
 | 52 | + | 
 | 53 | +        # Otherwise, download everything  | 
 | 54 | +        else:  | 
 | 55 | +            ignore_patterns = None  | 
 | 56 | + | 
31 | 57 |         snapshot_download(  | 
32 | 58 |             model_config.distribution_path,  | 
33 | 59 |             local_dir=artifact_dir,  | 
34 | 60 |             local_dir_use_symlinks=False,  | 
35 | 61 |             token=hf_token,  | 
36 |  | -            ignore_patterns="*safetensors*",  | 
 | 62 | +            ignore_patterns=ignore_patterns,  | 
37 | 63 |         )  | 
38 | 64 |     except HTTPError as e:  | 
39 | 65 |         if e.response.status_code == 401:  # Missing HuggingFace CLI login.  | 
 | 
0 commit comments