This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit c333a78

Add initial PR for generating and loading state dict

Parent: 7fe2c86

2 files changed: 45 additions, 18 deletions

torchchat/cli/builder.py (39 additions, 18 deletions)

@@ -61,6 +61,8 @@ class BuilderArgs:
     dynamic_shapes: bool = False
     max_seq_length: Optional[int] = None
 
+    quantized_state_path: Optional[Union[Path, str]] = None
+
     def __post_init__(self):
         if self.device is None:
             self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -171,6 +173,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
+            quantized_state_path=args.quantized_state_path,
         )
 
     @classmethod
@@ -565,25 +568,43 @@ def _initialize_model(
     model = _load_model(builder_args)
     device_sync(device=builder_args.device)
 
-    if quantize:
-        print(f"Quantizing the model with: {quantize}")
-        with measure_time("Time to quantize model: {time:.02f} seconds"):
-            quantize_model(
-                model,
-                builder_args.device,
-                quantize,
-                tokenizer,
-                support_tensor_subclass,
-            )
-        device_sync(device=builder_args.device)
+    cache_path = builder_args.quantized_state_path
+    quant_checkpoint_exists: bool = os.path.isfile(cache_path)
+    if quantize or quant_checkpoint_exists:
 
-    if builder_args.setup_caches:
-        with torch.device(builder_args.device):
-            model.setup_caches(
-                max_batch_size=1,
-                max_seq_length=max_seq_length
-                or model.text_transformer_args.max_seq_length,
-            )
+        if quantize and quant_checkpoint_exists:
+            print("WARNING: Both a quantized checkpoint and quantize arg were provided; Ignoring quantize arg")
+
+        if quant_checkpoint_exists:
+            with measure_time("Time to load quantized state: {time:.02f} seconds"):
+                print(f"Loading the model_state in: {cache_path}")
+                model.load_state_dict(cache_path)
+            device_sync(device=builder_args.device)
+        else:
+            with measure_time("Time to quantize model: {time:.02f} seconds"):
+                print(f"Quantizing the model with: {quantize}")
+                quantize_model(
+                    model,
+                    builder_args.device,
+                    quantize,
+                    tokenizer,
+                    support_tensor_subclass,
+                )
+            device_sync(device=builder_args.device)
+
+            if cache_path:
+                with measure_time("Time to save quantized state: {time:.02f} seconds"):
+                    print(f"Saving the quantized state dict")
+                    torch.save(model.state_dict(), cache_path)
+
+
+    if builder_args.setup_caches:
+        with torch.device(builder_args.device):
+            model.setup_caches(
+                max_batch_size=1,
+                max_seq_length=max_seq_length
+                or model.text_transformer_args.max_seq_length,
+            )
 
     model.to(dtype=builder_args.precision)
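The hunk above loads a cached quantized state dict when the path already exists, and otherwise quantizes the model and writes the state dict out. Two details of this initial version would likely need guards: os.path.isfile(None) raises a TypeError when no --quantized-state-path is given, and model.load_state_dict expects a mapping rather than a file path, so the checkpoint would first be deserialized with torch.load. A minimal sketch of the same flow with those guards (the helper name load_or_quantize is hypothetical, quantize_model is the function the hunk already calls, and the measure_time/device_sync calls are omitted for brevity):

    import os

    import torch


    def load_or_quantize(model, builder_args, quantize, tokenizer, support_tensor_subclass):
        # Hypothetical helper mirroring the _initialize_model hunk above.
        cache_path = builder_args.quantized_state_path
        # Guard: os.path.isfile(None) raises TypeError, so check for None first.
        quant_checkpoint_exists = cache_path is not None and os.path.isfile(cache_path)

        if quantize and quant_checkpoint_exists:
            print("WARNING: both a quantized checkpoint and a quantize arg were provided; ignoring quantize arg")

        if quant_checkpoint_exists:
            # load_state_dict expects a mapping, not a path: deserialize first.
            state_dict = torch.load(cache_path, map_location=builder_args.device)
            model.load_state_dict(state_dict)
        elif quantize:
            # quantize_model is the same helper builder.py already uses.
            quantize_model(model, builder_args.device, quantize, tokenizer, support_tensor_subclass)
            if cache_path is not None:
                # Cache the quantized weights so later runs can skip quantization.
                torch.save(model.state_dict(), cache_path)
        return model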

torchchat/cli/cli.py (6 additions, 0 deletions)

@@ -148,6 +148,12 @@ def _add_model_config_args(parser, verb: str) -> None:
         help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
     )
 
+    model_config_parser.add_argument(
+        "--quantized-state-path",
+        type=str,
+        default=None,
+        help="Quantized state_dict to load (if path exists) or write out to (if path doesn't exist)",
+    )
     model_config_parser.add_argument(
         "--dtype",
         default="fast",
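With the flag plumbed through BuilderArgs, the first run pays the quantization cost and caches the result, and later runs load it directly. A hypothetical invocation (the model alias and quantize config are placeholders for whatever the user would normally pass):

    # First run: no file at the path yet, so the model is quantized and the
    # state dict is written out.
    python3 torchchat.py generate llama3.1 --quantize config.json --quantized-state-path /tmp/llama3_quant.pt

    # Later runs: the file exists, so the quantized state is loaded directly
    # and the --quantize arg can be dropped.
    python3 torchchat.py generate llama3.1 --quantized-state-path /tmp/llama3_quant.pt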
