This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 763a9ce

Generalize to any state_dict
1 parent f3fbc91 commit 763a9ce

File tree: 2 files changed (+38 −32 lines)


torchchat/cli/builder.py

Lines changed: 34 additions & 28 deletions
@@ -16,20 +16,12 @@
 import torch._inductor.config
 import torch.nn as nn
 
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
-
 from torch.distributed.device_mesh import DeviceMesh
 
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.training import set_default_dtype
+from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
 
 from torchchat.model import Model, ModelArgs, ModelType
 
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
     device_sync,
@@ -40,6 +32,14 @@
 from torchchat.utils.measure_time import measure_time
 from torchchat.utils.quantize import quantize_model
 
+from torchtune.models.convert_weights import meta_to_tune
+
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
+from torchtune.training import set_default_dtype
+
 
 @dataclass
 class BuilderArgs:
@@ -61,7 +61,7 @@ class BuilderArgs:
     dynamic_shapes: bool = False
     max_seq_length: Optional[int] = None
 
-    quantized_state_path: Optional[Union[Path, str]] = None
+    state_dict_path: Optional[Union[Path, str]] = None
 
     def __post_init__(self):
         if self.device is None:
@@ -89,7 +89,9 @@ def __post_init__(self):
             ]
             for param, param_msg in ignored_params:
                 if param:
-                    print(f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified")
+                    print(
+                        f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified"
+                    )
         else:
             self.prefill_possible = True
 
@@ -173,7 +175,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
-            quantized_state_path=args.quantized_state_path,
+            state_dict_path=args.state_dict_path,
         )
 
     @classmethod
@@ -400,10 +402,10 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
         # does not host any actual values, need to reinitialize them in the actual
         # device. Only do those buffer initialization, without initializing the entire
         # model.
-        decoder_config = model.config.transformer_args['decoder']
-        head_dim = decoder_config['embed_dim'] // decoder_config['num_heads']
-        max_seq_len = decoder_config['max_seq_len']
-        rope_base = decoder_config['rope_base']
+        decoder_config = model.config.transformer_args["decoder"]
+        head_dim = decoder_config["embed_dim"] // decoder_config["num_heads"]
+        max_seq_len = decoder_config["max_seq_len"]
+        rope_base = decoder_config["rope_base"]
         for submodule in model.modules():
             if isinstance(submodule, Llama3ScaledRoPE):
                 submodule.__init__(head_dim, max_seq_len, rope_base)
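
For context on the hunk above: models materialized on the meta device carry buffers with shapes but no real storage, so the code re-runs __init__ on each Llama3ScaledRoPE submodule to rebuild its caches on the actual device. A minimal runnable sketch of that reinitialization pattern, using a toy module rather than torchtune's actual class:

import torch
import torch.nn as nn

class ToyRoPE(nn.Module):
    """Illustrative stand-in for Llama3ScaledRoPE: precomputes a cache buffer."""
    def __init__(self, head_dim: int, max_seq_len: int, base: int = 500000):
        super().__init__()
        theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        # Position * frequency products, shaped (max_seq_len, head_dim // 2).
        self.register_buffer("cache", torch.outer(torch.arange(max_seq_len).float(), theta))

# On the meta device the buffer exists but holds no actual values.
with torch.device("meta"):
    model = nn.Sequential(ToyRoPE(head_dim=128, max_seq_len=8192))

# Mirror the diff: walk the modules and re-run __init__ on the real device,
# rebuilding only the buffers instead of reloading the whole model.
with torch.device("cpu"):
    for submodule in model.modules():
        if isinstance(submodule, ToyRoPE):
            submodule.__init__(head_dim=128, max_seq_len=8192)

assert model[0].cache.device.type == "cpu"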
@@ -491,6 +493,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
 
+
 def _initialize_model(
     builder_args: BuilderArgs,
     quantize,
@@ -568,17 +571,19 @@ def _initialize_model(
         model = _load_model(builder_args)
         device_sync(device=builder_args.device)
 
-        cache_path = builder_args.quantized_state_path
-        quant_checkpoint_exists: bool = cache_path and os.path.isfile(cache_path)
-        if quantize or quant_checkpoint_exists:
+        state_dict_path = builder_args.state_dict_path
+        state_dict_exists: bool = state_dict_path and os.path.isfile(state_dict_path)
+        if quantize or state_dict_exists:
 
-            if quantize and quant_checkpoint_exists:
-                print("WARNING: Both a quantized checkpoint and quantize arg were provided; Ignoring quantize arg")
+            if quantize and state_dict_exists:
+                print(
+                    "WARNING: Both a state_dict and quantize arg were provided; Ignoring quantize arg"
+                )
 
-            if quant_checkpoint_exists:
+            if state_dict_exists:
                 with measure_time("Time to load quantized state: {time:.02f} seconds"):
-                    print(f"Loading the model_state in: {cache_path}")
-                    model.load_state_dict(cache_path)
+                    print(f"Loading the model_state in: {state_dict_path}")
+                    model.load_state_dict(state_dict_path)
                 device_sync(device=builder_args.device)
             else:
                 with measure_time("Time to quantize model: {time:.02f} seconds"):
@@ -592,11 +597,12 @@
                     )
                 device_sync(device=builder_args.device)
 
-            if cache_path:
-                with measure_time("Time to save quantized state: {time:.02f} seconds"):
+            if state_dict_path:
+                with measure_time(
+                    "Time to save quantized state: {time:.02f} seconds"
+                ):
                     print(f"Saving the quantized state dict")
-                    torch.save(model.state_dict(), cache_path)
-
+                    torch.save(model.state_dict(), state_dict_path)
 
     if builder_args.setup_caches:
         with torch.device(builder_args.device):
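
Taken together, the two hunks above form a simple cache-aside flow: when the file at state_dict_path exists it wins and the quantize arg is ignored; otherwise the model is quantized and, if a path was given, the resulting state dict is written out for the next run. Below is a minimal self-contained sketch of the same control flow; quantize_stub is a hypothetical placeholder (not torchchat's quantize_model), and the sketch loads via torch.load(), whereas the diff passes the path straight to model.load_state_dict():

import os
from typing import Optional

import torch
import torch.nn as nn

def quantize_stub(model: nn.Module) -> None:
    """Hypothetical stand-in for torchchat's quantize_model()."""
    pass

def load_or_quantize(
    model: nn.Module, state_dict_path: Optional[str], quantize: bool
) -> nn.Module:
    state_dict_exists = bool(state_dict_path) and os.path.isfile(state_dict_path)
    if quantize or state_dict_exists:
        if state_dict_exists:
            # A saved state dict takes precedence; the quantize arg is ignored.
            model.load_state_dict(torch.load(state_dict_path))
        else:
            quantize_stub(model)
        if state_dict_path:
            # Mirror the diff: write the state dict back whenever a path is set.
            torch.save(model.state_dict(), state_dict_path)
    return model

# First call quantizes and saves; later calls load the cached state instead.
model = load_or_quantize(nn.Linear(4, 4), "/tmp/linear_state.pt", quantize=True)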

torchchat/cli/cli.py

Lines changed: 4 additions & 4 deletions
@@ -149,10 +149,10 @@ def _add_model_config_args(parser, verb: str) -> None:
     )
 
     model_config_parser.add_argument(
-        "--quantized-state-path",
+        "--state-dict-path",
         type=str,
        default=None,
-        help="Quantized state_dict to load (if path exists) or write out to (if path doesn't exist)",
+        help="Model state dict to load (if path exists) or write out to (if path doesn't exist). Supersedes --quantize arg.",
     )
     model_config_parser.add_argument(
         "--dtype",
@@ -431,13 +431,13 @@ def _add_custom_model_args(parser) -> None:
         "--params-path",
         type=Path,
         default=None,
-        help= "Use the specified parameter file, instead of one specified under torchchat.model_params",
+        help="Use the specified parameter file, instead of one specified under torchchat.model_params",
     )
     parser.add_argument(
         "--tokenizer-path",
         type=Path,
         default=None,
-        help= "Use the specified model tokenizer file, instead of the one downloaded from HuggingFace",
+        help="Use the specified model tokenizer file, instead of the one downloaded from HuggingFace",
     )
 
 