This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 8c25a9d
Merge branch 'main' into stories_browser_fix
2 parents 64a2393 + b217158

4 files changed, +91 -11 lines changed

dist_run.py

Lines changed: 30 additions & 7 deletions
@@ -10,6 +10,7 @@
 
 import argparse
 import os
+from enum import auto, Enum
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple
@@ -49,6 +50,7 @@
 
 
 logger = SingletonLogger.get_logger()
+_tokenizer_type = None  # global variable to store the tokenizer type
 
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
@@ -59,6 +61,11 @@
 }
 
 
+class TokenizerType(Enum):
+    Tiktoken = auto()
+    SentencePiece = auto()
+
+
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
@@ -80,7 +87,10 @@ def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
-    """Builds a tokenizer for the given model name."""
+    """Builds a tokenizer for the given model name, and sets the global tokenizer type variable"""
+
+    global _tokenizer_type
+
     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
     if model_base_name is None:
@@ -107,6 +117,15 @@ def _build_chat_tokenizer(
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
+    # set global variable _tokenizer_type
+    if isinstance(tokenizer, TiktokenTokenizer):
+        _tokenizer_type = TokenizerType.Tiktoken
+    elif isinstance(tokenizer, SentencePieceProcessor):
+        _tokenizer_type = TokenizerType.SentencePiece
+    else:
+        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
+
+    logger.info(f"tokenizer type = {_tokenizer_type}")
     return tokenizer
 
 
@@ -269,6 +288,7 @@ def _cleanup():
 
 prompt = [
     "What is Snow?",
+    # "Can you explain what is the purpose of back propagation in neural networks?",
     "Who is Santa Claus?",
     "Where does Santa live?",
     # "Who is Abraham Lincoln?",
@@ -487,7 +507,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decorder = ScheduleGPipe(decode_stage, 1)
+    decoder = ScheduleGPipe(decode_stage, 1)
 
     # Decoding
     with torch.no_grad(), CUDATrackTime() as timer:
@@ -510,11 +530,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
         # Run data through pipeline
         if pp_rank == first_pp_rank:
-            output = decorder.step(new_token, **kwargs)
+            output = decoder.step(new_token, **kwargs)
         elif pp_rank == last_pp_rank:
-            output = decorder.step(**kwargs)
+            output = decoder.step(**kwargs)
         else:  # middle pp ranks
-            decorder.step(**kwargs)
+            decoder.step(**kwargs)
 
         # Decode the output
         if pp_rank == last_pp_rank:
@@ -539,13 +559,16 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # token ids. Thus cat'ing along dim 1.
         res = torch.cat(res, dim=1)
         res_list = res.tolist()
-        if isinstance(tokenizer, TiktokenTokenizer):
+        if _tokenizer_type == TokenizerType.Tiktoken:
            # For TiktokenTokenizer, we need to decode prompt by prompt.
            # TODO: is there a better way to do this?
            responses = [tokenizer.decode(sequence) for sequence in res_list]
-        else:  # SentencePieceProcessor
+        elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
            # For SentencePieceProcessor, we can decode the entire 2D list at once.
            responses = tokenizer.decode(res_list)
+        else:
+            raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")
+
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")

torchchat/cli/convert_hf_checkpoint.py

Lines changed: 32 additions & 2 deletions
@@ -3,6 +3,7 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import glob
 import json
 import os
 import re
@@ -41,7 +42,12 @@ def convert_hf_checkpoint(
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    model_map_json = model_dir / "pytorch_model.bin.index.json"
+    model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
+    assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
+    if len(model_map_json_matches):
+        model_map_json = model_map_json_matches[0]
+    else:
+        model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
@@ -96,9 +102,33 @@ def permute(w, n_heads):
 
     merged_result = {}
     for file in sorted(bin_files):
-        state_dict = torch.load(
+
+        # The state_dict can be loaded from either a torch zip file or
+        # safetensors. We take our best guess from the name and try all
+        # possibilities
+        load_pt_mmap = lambda: torch.load(
             str(file), map_location="cpu", mmap=True, weights_only=True
         )
+        load_pt_no_mmap = lambda: torch.load(
+            str(file), map_location="cpu", mmap=False, weights_only=True
+        )
+        def load_safetensors():
+            import safetensors.torch
+            with open(file, "rb") as handle:
+                return safetensors.torch.load(handle.read())
+        if "safetensors" in str(file):
+            loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap]
+        else:
+            loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors]
+
+        state_dict = None
+        for loader in loaders:
+            try:
+                state_dict = loader()
+                break
+            except Exception:
+                continue
+        assert state_dict is not None, f"Unable to load tensors from {file}"
         merged_result.update(state_dict)
     final_result = {}
     for key, value in merged_result.items():
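
The checkpoint-conversion change is a try-in-order loader fallback: guess the most likely format from the file name, then fall through the remaining loaders until one succeeds. A rough, self-contained sketch of that idea, assuming a generic load_checkpoint helper (not the repo's API):

from pathlib import Path
import torch

def load_checkpoint(file: Path) -> dict:
    # Candidate loaders, most likely first based on the file name.
    def load_pt_mmap():
        return torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)

    def load_pt_no_mmap():
        return torch.load(str(file), map_location="cpu", mmap=False, weights_only=True)

    def load_safetensors():
        import safetensors.torch
        with open(file, "rb") as handle:
            return safetensors.torch.load(handle.read())

    if "safetensors" in file.name:
        loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap]
    else:
        loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors]

    # Try each loader in turn; the first one that parses the file wins.
    for loader in loaders:
        try:
            return loader()
        except Exception:
            continue
    raise RuntimeError(f"Unable to load tensors from {file}")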

torchchat/cli/download.py

Lines changed: 28 additions & 2 deletions
@@ -22,18 +22,44 @@
 def _download_hf_snapshot(
     model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
 ):
-    from huggingface_hub import snapshot_download
+    from huggingface_hub import model_info, snapshot_download
     from requests.exceptions import HTTPError
 
     # Download and store the HF model artifacts.
     print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr)
     try:
+        # Fetch the info about the model's repo
+        model_info = model_info(model_config.distribution_path, token=hf_token)
+        model_fnames = [f.rfilename for f in model_info.siblings]
+
+        # Check the model config for preference between safetensors and pth
+        has_pth = any(f.endswith(".pth") for f in model_fnames)
+        has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)
+
+        # If told to prefer safetensors, ignore pth files
+        if model_config.prefer_safetensors:
+            if not has_safetensors:
+                print(
+                    f"Model {model_config.name} does not have safetensors files, but prefer_safetensors is set to True. Using pth files instead.",
+                    file=sys.stderr,
+                )
+                exit(1)
+            ignore_patterns = "*.pth"
+
+        # If the model has both, prefer pth files over safetensors
+        elif has_pth and has_safetensors:
+            ignore_patterns = "*safetensors*"
+
+        # Otherwise, download everything
+        else:
+            ignore_patterns = None
+
         snapshot_download(
             model_config.distribution_path,
             local_dir=artifact_dir,
             local_dir_use_symlinks=False,
             token=hf_token,
-            ignore_patterns="*safetensors*",
+            ignore_patterns=ignore_patterns,
         )
     except HTTPError as e:
         if e.response.status_code == 401:  # Missing HuggingFace CLI login.
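
The download change reduces to picking an ignore_patterns value for snapshot_download from two facts (does the repo contain .pth files, does it contain .safetensors files) plus the new prefer_safetensors flag; everything else is passed through unchanged. A condensed sketch of just that decision, where choose_ignore_patterns is an illustrative helper rather than part of torchchat:

from typing import List, Optional

def choose_ignore_patterns(
    model_fnames: List[str], prefer_safetensors: bool
) -> Optional[str]:
    has_pth = any(f.endswith(".pth") for f in model_fnames)
    has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)

    if prefer_safetensors:
        if not has_safetensors:
            # The diff prints a warning to stderr and calls exit(1) here.
            raise SystemExit(1)
        return "*.pth"          # skip the PyTorch checkpoints
    if has_pth and has_safetensors:
        return "*safetensors*"  # default: prefer pth when both formats exist
    return None                 # only one format available: download everything

With prefer_safetensors unset, the default branch matches the previous hard-coded ignore_patterns="*safetensors*" behavior whenever both formats are present, so existing model downloads are unaffected.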

torchchat/model_config/model_config.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ class ModelConfig:
     checkpoint_file: str = field(default="model.pth")
     tokenizer_file: str = field(default="tokenizer.model")
     transformer_params_key: str = field(default=None)
+    prefer_safetensors: bool = field(default=False)
 
 
 # Keys are stored in lowercase.
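
The new prefer_safetensors field defaults to False, so only model entries that set it explicitly opt into safetensors-first downloads; _download_hf_snapshot reads it as model_config.prefer_safetensors. A pared-down illustration of the dataclass default, including only the fields visible in this hunk:

from dataclasses import dataclass, field

@dataclass
class ModelConfig:
    checkpoint_file: str = field(default="model.pth")
    tokenizer_file: str = field(default="tokenizer.model")
    transformer_params_key: str = field(default=None)
    prefer_safetensors: bool = field(default=False)

cfg = ModelConfig()
print(cfg.prefer_safetensors)  # False unless a model entry overrides it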

0 commit comments