Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 69b17e3

Browse files
committed
support model snapshots to save quantized models
1 parent 654bb03 commit 69b17e3

File tree

3 files changed

+93
-3
lines changed

3 files changed

+93
-3
lines changed

torchchat/cli/builder.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class BuilderArgs:
5656
gguf_kwargs: Optional[Dict[str, Any]] = None
5757
dso_path: Optional[Union[Path, str]] = None
5858
aoti_package_path: Optional[Union[Path, str]] = None
59+
snapshot_path: Optional[Union[Path, str]] = None
5960
pte_path: Optional[Union[Path, str]] = None
6061
device: Optional[str] = None
6162
precision: torch.dtype = torch.float32
@@ -81,6 +82,7 @@ def __post_init__(self):
8182
or (self.dso_path and Path(self.dso_path).is_file())
8283
or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
8384
or (self.pte_path and Path(self.pte_path).is_file())
85+
or (self.snapshot_path and Path(self.snapshot_path).is_file())
8486
):
8587
raise RuntimeError(
8688
"need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, AOTI PACKAGE or PTE path"
@@ -136,6 +138,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
136138
dso_path = getattr(args, "dso_path", None)
137139
pte_path = getattr(args, "pte_path", None)
138140
aoti_package_path = getattr(args, "aoti_package_path", None)
141+
snapshot_path = getattr(args, "snapshot_path", None)
139142

140143
is_chat_model = False
141144
if args.is_chat_model:
@@ -163,6 +166,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
163166
output_pte_path = getattr(args, "output_pte_path", None)
164167
output_aoti_package_path = getattr(args, "output_aoti_package_path", None)
165168
output_dso_path = getattr(args, "output_dso_path", None)
169+
output_snapshot_path = getattr(args, "output_snapshot_path", None)
166170
if output_pte_path and args.dtype.startswith("fast"):
167171
if args.dtype == "fast":
168172
# As per Kimish, float32 should be faster on ET XNNPACK
@@ -189,6 +193,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
189193
dso_path=dso_path,
190194
aoti_package_path=aoti_package_path,
191195
pte_path=pte_path,
196+
snapshot_path=snapshot_path,
192197
device=args.device,
193198
precision=dtype,
194199
setup_caches=(
@@ -614,6 +619,33 @@ def do_nothing(max_batch_size, max_seq_length):
614619
model = PTEModel(config, builder_args.pte_path)
615620
except Exception:
616621
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
622+
elif builder_args.snapshot_path:
623+
# Resolve ModelArgs for constructing the PTEModel
624+
# If a manual params_path is provided, use that
625+
if builder_args.params_path:
626+
config: ModelArgs = ModelArgs.from_params(builder_args.params_path)
627+
else:
628+
# TODO: Instead of loading the whole model, refactor to call a
629+
# helper that generates just model.config
630+
with measure_time("Time to load model: {time:.02f} seconds"):
631+
model = _load_model(builder_args)
632+
device_sync(device=builder_args.device)
633+
config = model.config
634+
model = None
635+
try:
636+
model = torch.load(builder_args.snapshot_path, weights_only=False)
637+
except Exception:
638+
raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
639+
# _active_backend() does not allow DSO & AOTI to be true.
640+
# Choose either.
641+
set_backend (dso=True, pte=False, aoti_package=False)
642+
if (model.config != config):
643+
raise RuntimeError("loaded model architecture mismatch")
644+
##
645+
## import all libraries with custom kernels and custom operators
646+
## that quantize may be pulling in
647+
##
648+
617649
elif builder_args.distributed:
618650
pp_degree = builder_args.pp
619651
tp_degree = builder_args.tp

torchchat/cli/cli.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,12 @@ def _add_export_output_path_args(parser) -> None:
200200
default=None,
201201
help="Output to the specified AOT Inductor .dso model file",
202202
)
203+
exclusive_parser.add_argument(
204+
"--output-snapshot-path",
205+
type=str,
206+
default=None,
207+
help="Output to the specified PyTorch model and sha256 file",
208+
)
203209
exclusive_parser.add_argument(
204210
"--output-aoti-package-path",
205211
type=str,
@@ -247,7 +253,13 @@ def _add_exported_input_path_args(parser) -> None:
247253
default=None,
248254
help="Use the specified ExecuTorch .pte model file",
249255
)
250-
256+
exclusive_parser.add_argument(
257+
"--snapshot-path",
258+
type=Path,
259+
default=None,
260+
help="Use the specified torchchat snaphot .tc model file",
261+
)
262+
251263

252264
# Add CLI Args related to JIT downloading of model artifacts
253265
def _add_jit_downloading_args(parser) -> None:

torchchat/export.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,31 @@
2828
default_device = "cpu"
2929

3030

31+
"""
32+
Export Snapshot
33+
"""
34+
35+
36+
def export_snapshot(
37+
model: nn.Module,
38+
device: Optional[str] = None,
39+
output_path: str = "model-snapshot.tc",
40+
) -> str:
41+
"""
42+
Export the model as snapshot.
43+
44+
Args:
45+
model: The model to be exported.
46+
device: The device to run the model on.
47+
output_path: The path to save the exported model.
48+
Returns:
49+
The path to the exported model.
50+
"""
51+
assert output_path.endswith(".tc"), "use .tc extension for snapshots"
52+
torch.save(model, output_path)
53+
return output_path
54+
55+
3156
"""
3257
Export for Server
3358
"""
@@ -66,7 +91,7 @@ def export_for_server(
6691
)
6792
dynamic_shapes = None
6893

69-
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
94+
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION ]):
7095
metadata = {} # TODO: put more metadata here
7196
options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
7297
if not package:
@@ -359,14 +384,15 @@ def main(args):
359384

360385
output_pte_path = args.output_pte_path
361386
output_dso_path = args.output_dso_path
387+
output_snapshot_path = args.output_snapshot_path
362388
output_aoti_package_path = args.output_aoti_package_path
363389

364390
if output_pte_path and builder_args.device != "cpu":
365391
print(
366392
f"Warning! ExecuTorch export target is controlled by export recipe, not device setting. Ignoring device={builder_args.device} setting."
367393
)
368394
builder_args.device = "cpu"
369-
elif "mps" in builder_args.device:
395+
elif (output_pte_path or output_dso_path or output_aoti_package_path) and "mps" in builder_args.device:
370396
print("Warning! Device MPS not supported for export. Exporting for device CPU.")
371397
builder_args.device = "cpu"
372398

@@ -402,6 +428,7 @@ def main(args):
402428
model_to_pte = model
403429
model_to_dso = model
404430
model_to_aoti_package = model
431+
model_to_snapshot = model
405432
else:
406433
if output_pte_path:
407434
_set_gguf_kwargs(builder_args, is_et=True, context="export")
@@ -421,6 +448,15 @@ def main(args):
421448
model_to_dso = model_to_aoti_package
422449
_unset_gguf_kwargs(builder_args)
423450

451+
if output_snapshot_path:
452+
_set_gguf_kwargs(builder_args, is_et=False, context="export")
453+
model_to_snapshot = _initialize_model(
454+
builder_args,
455+
quantize,
456+
support_tensor_subclass=False,
457+
)
458+
_unset_gguf_kwargs(builder_args)
459+
424460
with torch.no_grad():
425461
if output_pte_path:
426462
output_pte_path = str(os.path.abspath(output_pte_path))
@@ -454,3 +490,13 @@ def main(args):
454490
builder_args.dynamic_shapes,
455491
package=True,
456492
)
493+
494+
if output_snapshot_path:
495+
output_snapshot_path = str(os.path.abspath(output_snapshot_path))
496+
print(f"Exporting model using Snapshot to {output_snapshot_path}")
497+
export_snapshot(
498+
model_to_snapshot,
499+
builder_args.device,
500+
output_snapshot_path,
501+
)
502+

0 commit comments

Comments
 (0)