
Commit a392191

fix(convert hf): Better logic to handle multiple weight mapping files
This will not actually be needed for Mistral with the fix in download to handle .bin files, but it may be needed for other models, so it's worth having.

Branch: download-fix
Signed-off-by: Gabe Goodhart <[email protected]>
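For context, a minimal sketch (not part of the commit; the directory path is illustrative) of why globbing for index files can return more than one match — a single Hugging Face snapshot may ship both a .bin index and a .safetensors index, which the old single-file assumption did not tolerate outside of Mistral:

from pathlib import Path
import glob

# Hypothetical checkpoint directory (path is illustrative, not from the commit).
model_dir = Path("checkpoints/some-org/some-model")

# A single Hugging Face snapshot can carry more than one index file, e.g.
#   pytorch_model.bin.index.json    -> maps weight names to *.bin shards
#   model.safetensors.index.json    -> maps weight names to *.safetensors shards
# Globbing therefore may return several matches; the new code merges them
# instead of asserting there is at most one.
model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]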
Parent: de8c92f

1 file changed: +23 −12 lines

torchchat/cli/convert_hf_checkpoint.py

@@ -41,20 +41,14 @@ def convert_hf_checkpoint(
     config = TransformerArgs.from_params(config_args)
     print(f"Model config {config.__dict__}")
 
-    # Load the json file containing weight mapping
+    # Find all candidate weight mapping index files
     model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
-    if "mistral" not in model_name:
-        assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
-    if len(model_map_json_matches):
-        model_map_json = model_map_json_matches[0]
-    else:
-        model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
     # Llama 3 has a consolidated model and tokenizer.
     # Otherwise raise an error.
-    if not model_map_json.is_file():
+    if not model_map_json_matches:
         consolidated_pth = model_dir / "original" / "consolidated.00.pth"
         tokenizer_pth = model_dir / "original" / "tokenizer.model"
         if consolidated_pth.is_file() and tokenizer_pth.is_file():
@@ -70,11 +64,29 @@ def convert_hf_checkpoint(
             return
         else:
             raise RuntimeError(
-                f"Could not find {model_map_json} or {consolidated_pth} plus {tokenizer_pth}"
+                f"Could not find a valid model weight map or {consolidated_pth} plus {tokenizer_pth}"
             )
 
-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
+    # Load the json file(s) containing weight mapping
+    #
+    # NOTE: If there are multiple index files, there are two possibilities:
+    #   1. The files could be mapped to different weight format files (e.g. .bin
+    #      vs .safetensors)
+    #   2. The files could be split subsets of the mappings that need to be
+    #      merged
+    #
+    # In either case, we can simply keep the mappings where the target file is
+    # valid in the model dir.
+    bin_files = {}
+    for weight_map_file in model_map_json_matches:
+        with open(weight_map_file, "r") as handle:
+            weight_map = json.load(handle)
+        valid_mappings = {
+            k: model_dir / v
+            for (k, v) in weight_map.get("weight_map", {}).items()
+            if (model_dir / v).is_file()
+        }
+        bin_files.update(valid_mappings)
 
     weight_map = {
         "model.embed_tokens.weight": "tok_embeddings.weight",
@@ -98,7 +110,6 @@ def convert_hf_checkpoint(
         "model.norm.weight": "norm.weight",
         "lm_head.weight": "output.weight",
     }
-    bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
         return (
