diff --git a/torchchat/cli/convert_hf_checkpoint.py b/torchchat/cli/convert_hf_checkpoint.py
index 1e5d3eaf7..7e3e2d676 100644
--- a/torchchat/cli/convert_hf_checkpoint.py
+++ b/torchchat/cli/convert_hf_checkpoint.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import glob
 import json
 import os
 import re
@@ -41,7 +42,12 @@ def convert_hf_checkpoint(
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    model_map_json = model_dir / "pytorch_model.bin.index.json"
+    model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
+    assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
+    if len(model_map_json_matches):
+        model_map_json = model_map_json_matches[0]
+    else:
+        model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
@@ -96,9 +102,33 @@ def permute(w, n_heads):
 
     merged_result = {}
     for file in sorted(bin_files):
-        state_dict = torch.load(
+
+        # The state_dict can be loaded from either a torch zip file or
+        # safetensors. We take our best guess from the name and try all
+        # possibilities
+        load_pt_mmap = lambda: torch.load(
             str(file), map_location="cpu", mmap=True, weights_only=True
         )
+        load_pt_no_mmap = lambda: torch.load(
+            str(file), map_location="cpu", mmap=False, weights_only=True
+        )
+        def load_safetensors():
+            import safetensors.torch
+            with open(file, "rb") as handle:
+                return safetensors.torch.load(handle.read())
+        if "safetensors" in str(file):
+            loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap]
+        else:
+            loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors]
+
+        state_dict = None
+        for loader in loaders:
+            try:
+                state_dict = loader()
+                break
+            except Exception:
+                continue
+        assert state_dict is not None, f"Unable to load tensors from {file}"
         merged_result.update(state_dict)
     final_result = {}
     for key, value in merged_result.items():
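For reference, the loader fallback added to convert_hf_checkpoint.py boils down to the standalone pattern below. This is a minimal sketch, not code from the patch: the checkpoint path is hypothetical, and the two mmap variants are collapsed into a single torch.load call.

# Sketch of the try-each-loader fallback, with a hypothetical path.
from pathlib import Path

import torch

file = Path("checkpoint.safetensors")  # hypothetical checkpoint file

def load_pt():
    # Plain PyTorch (zip) checkpoints.
    return torch.load(str(file), map_location="cpu", weights_only=True)

def load_safetensors():
    # safetensors.torch.load parses the raw bytes of a .safetensors file.
    import safetensors.torch
    with open(file, "rb") as handle:
        return safetensors.torch.load(handle.read())

# Guess the format from the file name, but keep the other loader as a fallback.
if "safetensors" in str(file):
    loaders = [load_safetensors, load_pt]
else:
    loaders = [load_pt, load_safetensors]

state_dict = None
for loader in loaders:
    try:
        state_dict = loader()
        break
    except Exception:
        continue
assert state_dict is not None, f"Unable to load tensors from {file}"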
diff --git a/torchchat/cli/download.py b/torchchat/cli/download.py
index 6ac3e8d9d..14dfeb062 100644
--- a/torchchat/cli/download.py
+++ b/torchchat/cli/download.py
@@ -22,18 +22,44 @@ def _download_hf_snapshot(
     model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
 ):
-    from huggingface_hub import snapshot_download
+    from huggingface_hub import model_info, snapshot_download
     from requests.exceptions import HTTPError
 
     # Download and store the HF model artifacts.
     print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr)
     try:
+        # Fetch the info about the model's repo
+        model_info = model_info(model_config.distribution_path, token=hf_token)
+        model_fnames = [f.rfilename for f in model_info.siblings]
+
+        # Check which weight formats the repo actually provides
+        has_pth = any(f.endswith(".pth") for f in model_fnames)
+        has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)
+
+        # If told to prefer safetensors, ignore pth files
+        if model_config.prefer_safetensors:
+            if not has_safetensors:
+                print(
+                    f"Model {model_config.name} does not have safetensors files, but prefer_safetensors is set to True.",
+                    file=sys.stderr,
+                )
+                exit(1)
+            ignore_patterns = "*.pth"
+
+        # If the model has both, prefer pth files over safetensors
+        elif has_pth and has_safetensors:
+            ignore_patterns = "*safetensors*"
+
+        # Otherwise, download everything
+        else:
+            ignore_patterns = None
+
         snapshot_download(
             model_config.distribution_path,
             local_dir=artifact_dir,
             local_dir_use_symlinks=False,
             token=hf_token,
-            ignore_patterns="*safetensors*",
+            ignore_patterns=ignore_patterns,
         )
     except HTTPError as e:
         if e.response.status_code == 401:  # Missing HuggingFace CLI login.
diff --git a/torchchat/model_config/model_config.py b/torchchat/model_config/model_config.py
index 584a87a74..540804ada 100644
--- a/torchchat/model_config/model_config.py
+++ b/torchchat/model_config/model_config.py
@@ -46,6 +46,7 @@ class ModelConfig:
     checkpoint_file: str = field(default="model.pth")
     tokenizer_file: str = field(default="tokenizer.model")
     transformer_params_key: str = field(default=None)
+    prefer_safetensors: bool = field(default=False)
 
     # Keys are stored in lowercase.
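The branching in download.py that picks ignore_patterns (including the new prefer_safetensors flag from model_config.py) can be summarized as a pure function. The sketch below is illustrative only: choose_ignore_patterns is a hypothetical name, and it raises where the patch prints an error and exits.

# Hypothetical helper distilling the ignore_patterns decision above.
from typing import List, Optional

def choose_ignore_patterns(
    model_fnames: List[str], prefer_safetensors: bool
) -> Optional[str]:
    has_pth = any(f.endswith(".pth") for f in model_fnames)
    has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)
    if prefer_safetensors:
        if not has_safetensors:
            # The patch prints to stderr and calls exit(1) here; a
            # library-style helper raises instead.
            raise ValueError("prefer_safetensors is set, but no safetensors files exist")
        return "*.pth"          # skip pth, keep safetensors
    if has_pth and has_safetensors:
        return "*safetensors*"  # default: keep pth, skip safetensors
    return None                 # single format: download everything

# One check per branch:
assert choose_ignore_patterns(["model.safetensors"], True) == "*.pth"
assert choose_ignore_patterns(["model.pth", "model.safetensors"], False) == "*safetensors*"
assert choose_ignore_patterns(["model.pth"], False) is None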
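The repo listing itself comes from huggingface_hub.model_info, whose .siblings entries expose each file's rfilename, as the patch uses. A quick sketch against a public repo ("gpt2" is just an example id; any accessible model id works the same way):

# Listing a repo's filenames the same way download.py does.
from huggingface_hub import model_info

info = model_info("gpt2")  # example public repo id
fnames = [s.rfilename for s in info.siblings]
print(sorted(f for f in fnames if f.endswith((".bin", ".safetensors"))))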