This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 9e1f421

Merge branch 'main' into bump104
2 parents 6f081ce + 6a2a2e8 commit 9e1f421

11 files changed: +478 additions, -293 deletions

dist_run.py

Lines changed: 66 additions & 29 deletions
@@ -10,6 +10,7 @@
 
 import argparse
 import os
+from enum import auto, Enum
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple
@@ -22,10 +23,10 @@
 from torchchat.distributed.logging_utils import SingletonLogger
 
 # TODO - these are not distributed specific, consider moving to new package
-from torchchat.distributed.safetensor_utils import (
+from torchchat.distributed.checkpoint_utils import (
     get_hf_config_file,
-    get_hf_weight_map_and_path,
-    load_safetensor_weights,
+    load_weights_from_hf_format,
+    load_weights_from_torchchat_format,
 )
 from torchchat.distributed.utils import (
     bytes_to_readable,
@@ -49,6 +50,7 @@
 
 
 logger = SingletonLogger.get_logger()
+_tokenizer_type = None  # global variable to store the tokenizer type
 
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
@@ -59,6 +61,11 @@
 }
 
 
+class TokenizerType(Enum):
+    Tiktoken = auto()
+    SentencePiece = auto()
+
+
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
@@ -80,7 +87,10 @@ def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
-    """Builds a tokenizer for the given model name."""
+    """Builds a tokenizer for the given model name, and sets the global tokenizer type variable"""
+
+    global _tokenizer_type
+
     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
     if model_base_name is None:
@@ -107,29 +117,45 @@ def _build_chat_tokenizer(
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
+    # set global variable _tokenizer_type
+    if isinstance(tokenizer, TiktokenTokenizer):
+        _tokenizer_type = TokenizerType.Tiktoken
+    elif isinstance(tokenizer, SentencePieceProcessor):
+        _tokenizer_type = TokenizerType.SentencePiece
+    else:
+        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
+
+    logger.info(f"tokenizer type = {_tokenizer_type}")
     return tokenizer
 
 
-def _load_model_weights(stage_module, distribution, device, model_config):
+def _load_model_weights(
+    stage_module: torch.nn.Module,
+    distribution: str,
+    device: torch.device,
+    model_config: ModelArgs,
+    chpt_from: str,
+):
     """Load the weights from the safetensor file(s) into the model stage.
     Model config is needed b/c we permute wq and wk weights based on attn heads.
-    """
 
-    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
-
-    num_loaded_weights, num_missing_weights = load_safetensor_weights(
-        stage_module,
-        weight_map,
-        weight_path,
-        key_map,
-        device,
-        model_config=model_config,
-    )
-    logger.info(
-        f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
-    )
-    if num_missing_weights > 0:
-        raise ValueError(f"Missing {num_missing_weights} weights")
+    Args:
+        stage_module (torch.nn.Module): The model stage to load the weights into.
+        distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
+        device (torch.device): The device to load the weights onto.
+        model_config (ModelArgs): The model config.
+        chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
+    """
+    if chpt_from == "hf":
+        # This format stands for: index file + multiple binary files
+        load_weights_from_hf_format(stage_module, distribution, device, model_config)
+    elif chpt_from == "torchchat":
+        # This format stands for:
+        # single binary file, OR
+        # multiple binary files without index files.
+        load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
+    else:
+        raise ValueError(f"Unknown checkpoint format: {chpt_from}")
 
 
 def _encode_strings(
@@ -269,6 +295,7 @@ def _cleanup():
 
 prompt = [
     "What is Snow?",
+    # "Can you explain what is the purpose of back propagation in neural networks?",
     "Who is Santa Claus?",
     "Where does Santa live?",
     # "Who is Abraham Lincoln?",
@@ -286,7 +313,7 @@ def main(args):
     logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
 
     distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
-    logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
+    logger.info(f"Using model weights from {distribution} and dtype {model_dtype}")
 
     # Model-level config
     model_config = ModelArgs.from_name(distribution)
@@ -348,7 +375,7 @@ def main(args):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        _load_model_weights(model, distribution, device=device, model_config=config)
+        _load_model_weights(model, distribution, device, config, args.chpt_from)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -487,7 +514,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decorder = ScheduleGPipe(decode_stage, 1)
+    decoder = ScheduleGPipe(decode_stage, 1)
 
     # Decoding
     with torch.no_grad(), CUDATrackTime() as timer:
@@ -510,11 +537,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decorder.step(new_token, **kwargs)
+                output = decoder.step(new_token, **kwargs)
             elif pp_rank == last_pp_rank:
-                output = decorder.step(**kwargs)
+                output = decoder.step(**kwargs)
             else:  # middle pp ranks
-                decorder.step(**kwargs)
+                decoder.step(**kwargs)
 
             # Decode the output
             if pp_rank == last_pp_rank:
@@ -539,13 +566,16 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # token ids. Thus cat'ing along dim 1.
         res = torch.cat(res, dim=1)
        res_list = res.tolist()
-        if isinstance(tokenizer, TiktokenTokenizer):
+        if _tokenizer_type == TokenizerType.Tiktoken:
            # For TiktokenTokenizer, we need to decode prompt by prompt.
            # TODO: is there a better way to do this?
            responses = [tokenizer.decode(sequence) for sequence in res_list]
-        else:  # SentencePieceProcessor
+        elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
            # For SentencePieceProcessor, we can decode the entire 2D list at once.
            responses = tokenizer.decode(res_list)
+        else:
+            raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")
+
        # Show prompts and responses
        for prompt_text, response_text in zip(prompt, responses):
            logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
@@ -579,6 +609,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
        default=False,
        help="Whether to decode token into string in flight",
    )
+    parser.add_argument(
+        "--chpt-from",
+        type=str,
+        default="hf",  # TODO: change to torchchat once we support it well
+        help="Checkpoint format to load from",
+        choices=["hf", "torchchat"],
+    )
    args = parser.parse_args()
 
    main(args)
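A note on the new flag (commentary, not part of the diff): argparse converts the dash in "--chpt-from" into an underscore, which is how the value reaches the _load_model_weights(model, distribution, device, config, args.chpt_from) call above. A minimal standalone sketch mirroring the add_argument call in the diff:

import argparse

# Sketch only: the --chpt-from option as defined in the diff above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--chpt-from",
    type=str,
    default="hf",
    choices=["hf", "torchchat"],
    help="Checkpoint format to load from",
)

# argparse exposes "--chpt-from" as the attribute "chpt_from".
args = parser.parse_args(["--chpt-from", "torchchat"])
assert args.chpt_from == "torchchat"

# A value outside `choices` makes parse_args() exit with an error, so
# _load_model_weights only ever sees "hf" or "torchchat".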

install/install_requirements.sh

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ TUNE_NIGHTLY_VERSION=dev20240928
 if [[ -x "$(command -v nvidia-smi)" ]];
 then
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cu121"
+elif [[ -x "$(command -v rocminfo)" ]];
+then
+  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/rocm6.2"
 else
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
 fi
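For readers less familiar with shell, the detection above simply checks whether the nvidia-smi or rocminfo binary is on PATH and picks the matching nightly wheel index. A rough Python equivalent, for illustration only (not part of the commit; shutil.which plays the role of `command -v`):

import shutil

# Illustration only: the same CUDA / ROCm / CPU selection the shell script performs.
if shutil.which("nvidia-smi"):
    torch_nightly_url = "https://download.pytorch.org/whl/nightly/cu121"
elif shutil.which("rocminfo"):
    torch_nightly_url = "https://download.pytorch.org/whl/nightly/rocm6.2"
else:
    torch_nightly_url = "https://download.pytorch.org/whl/nightly/cpu"

print(torch_nightly_url)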

torchchat/cli/builder.py

Lines changed: 11 additions & 5 deletions
@@ -335,11 +335,7 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model:
     return model
 
 
-def _load_model_default(builder_args: BuilderArgs) -> Model:
-    assert not builder_args.gguf_path
-
-    model: Model = _init_model_on_meta_device(builder_args)
-
+def _load_checkpoint(builder_args: BuilderArgs):
     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
         print("Loading Tune checkpoint")
         meta_checkpoint = torch.load(
@@ -377,6 +373,16 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
             mmap=True,
             weights_only=True,
         )
+    return checkpoint
+
+
+def _load_model_default(builder_args: BuilderArgs) -> Model:
+    assert not builder_args.gguf_path
+
+    model: Model = _init_model_on_meta_device(builder_args)
+
+    # Load checkpoint from filesystem
+    checkpoint = _load_checkpoint(builder_args)
 
     if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
         checkpoint = checkpoint["model"]
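The extracted _load_checkpoint helper is essentially a thin wrapper around torch.load, now reusable outside of _load_model_default. A minimal sketch of the underlying call (the "model.pth" path and map_location are placeholders, not values from the repo): mmap=True memory-maps the checkpoint instead of reading it fully into RAM, and weights_only=True restricts unpickling to tensors and plain containers.

import torch

# Sketch only: the kind of call _load_checkpoint wraps; "model.pth" is a hypothetical path.
checkpoint = torch.load(
    "model.pth",
    map_location="cpu",
    mmap=True,          # memory-map the file rather than loading it all into RAM
    weights_only=True,  # safer unpickling: tensors and basic containers only
)
print(f"loaded {len(checkpoint)} entries")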

torchchat/cli/convert_hf_checkpoint.py

Lines changed: 17 additions & 8 deletions
@@ -81,10 +81,17 @@ def convert_hf_checkpoint(
         "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
         "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
         "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+        "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias",
+        "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias",
+        "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias",
+        "model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias",
         "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
         "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
         "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
         "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+        "model.layers.{}.mlp.gate_proj.bias": "layers.{}.feed_forward.w1.bias",
+        "model.layers.{}.mlp.up_proj.bias": "layers.{}.feed_forward.w3.bias",
+        "model.layers.{}.mlp.down_proj.bias": "layers.{}.feed_forward.w2.bias",
         "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
         "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
         "model.norm.weight": "norm.weight",
@@ -93,11 +100,10 @@ def convert_hf_checkpoint(
     bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
-        dim = config.dim
         return (
-            w.view(n_heads, 2, config.head_dim // 2, dim)
+            w.view(n_heads, 2, config.head_dim // 2, *w.shape[1:])
             .transpose(1, 2)
-            .reshape(config.head_dim * n_heads, dim)
+            .reshape(w.shape)
         )
 
     merged_result = {}
@@ -130,6 +136,7 @@ def load_safetensors():
                 continue
         assert state_dict is not None, f"Unable to load tensors from {file}"
         merged_result.update(state_dict)
+
     final_result = {}
     for key, value in merged_result.items():
         if "layers" in key:
@@ -145,16 +152,18 @@ def load_safetensors():
             final_result[new_key] = value
 
     for key in tuple(final_result.keys()):
-        if "wq" in key:
+        if "wq.weight" in key or "wq.bias" in key:
+            wk_key = key.replace("wq", "wk")
+            wv_key = key.replace("wq", "wv")
             q = final_result[key]
-            k = final_result[key.replace("wq", "wk")]
-            v = final_result[key.replace("wq", "wv")]
+            k = final_result[wk_key]
+            v = final_result[wv_key]
             q = permute(q, config.n_heads)
             k = permute(k, config.n_local_heads)
             final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
             del final_result[key]
-            del final_result[key.replace("wq", "wk")]
-            del final_result[key.replace("wq", "wv")]
+            del final_result[wk_key]
+            del final_result[wv_key]
     print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
     torch.save(final_result, model_dir / "model.pth")
     print("Done.")
