This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
33 changes: 33 additions & 0 deletions torchchat/cli/builder.py
@@ -56,6 +56,7 @@ class BuilderArgs:
gguf_kwargs: Optional[Dict[str, Any]] = None
dso_path: Optional[Union[Path, str]] = None
aoti_package_path: Optional[Union[Path, str]] = None
snapshot_path: Optional[Union[Path, str]] = None
pte_path: Optional[Union[Path, str]] = None
device: Optional[str] = None
precision: torch.dtype = torch.float32
@@ -87,6 +88,7 @@ def __post_init__(self):
or (self.dso_path and Path(self.dso_path).is_file())
or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
or (self.pte_path and Path(self.pte_path).is_file())
or (self.snapshot_path and Path(self.snapshot_path).is_file())
):
raise RuntimeError(
"need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, AOTI PACKAGE or PTE path"
@@ -142,6 +144,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
dso_path = getattr(args, "dso_path", None)
pte_path = getattr(args, "pte_path", None)
aoti_package_path = getattr(args, "aoti_package_path", None)
snapshot_path = getattr(args, "snapshot_path", None)

is_chat_model = False
if args.is_chat_model:
@@ -169,6 +172,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
output_pte_path = getattr(args, "output_pte_path", None)
output_aoti_package_path = getattr(args, "output_aoti_package_path", None)
output_dso_path = getattr(args, "output_dso_path", None)
output_snapshot_path = getattr(args, "output_snapshot_path", None)
if output_pte_path and args.dtype.startswith("fast"):
if args.dtype == "fast":
# As per Kimish, float32 should be faster on ET XNNPACK
@@ -206,6 +210,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
dso_path=dso_path,
aoti_package_path=aoti_package_path,
pte_path=pte_path,
snapshot_path=snapshot_path,
device=args.device,
precision=dtype,
setup_caches=(
@@ -631,6 +636,34 @@ def do_nothing(max_batch_size, max_seq_length):
model = PTEModel(config, builder_args.pte_path)
except Exception:
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
elif builder_args.snapshot_path:
# Resolve ModelArgs for validating the loaded snapshot
# If a manual params_path is provided, use that
if builder_args.params_path:
config: ModelArgs = ModelArgs.from_params(builder_args.params_path)
else:
# TODO: Instead of loading the whole model, refactor to call a
# helper that generates just model.config
with measure_time("Time to load model: {time:.02f} seconds"):
model = _load_model(builder_args)
device_sync(device=builder_args.device)
Contributor:
Q: Does the saved artifact still work if the device has changed? I recall this being an issue with AO (one of the reasons why we didn't add saving earlier)

Contributor Author:
No, it might not; most likely not. It really depends on which quantizations are performed and whether they're implemented on both platforms, and in the same way. I.e., if it's the same PyTorch/Python code doing the computation with quantized numbers, and the same quantization formats are supported, then yes.

If it's a C/C++/CUDA operator, it needs to be supported on the target, with the same name, or behind a suitable if/then/else (i.e., don't bake the "device" setting in).

Quantization weight-format layouts need to be consistent, or the loader needs to repack them at load time. (That is totally plausible to do, but I don't think we do it today. I think in the 4b case "we just know". I tried to change that, but the need/priority wasn't perceived the same way by everybody.)

If it's saved and reloaded, most (all?) decisions you made are set in stone, like quantization schemes etc. (Otherwise, you'd be loading from scratch?) In some sense that's similar to how dso/aoti/pte-output-path and load-dso/aoti/pte-path work, and that's why this is modeled after that export-and-reload facility. You don't get to change the PTE target on reload, or the device an AOTI model has been compiled for.

Contributor:
I think expecting the exporting conditions to be the same as the executing conditions is a fair start.
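
A minimal sketch of the point discussed in this thread (editorial illustration, not code from this PR): the snapshot could carry its export conditions (device, torch version, quantization recipe) and the loader could refuse a mismatch up front instead of failing later inside a quantized kernel. The SnapshotPayload wrapper and helper names below are hypothetical, not torchchat APIs.

# Editorial sketch, not part of this diff. SnapshotPayload, save_snapshot and
# load_snapshot are hypothetical names used only for illustration.
from dataclasses import dataclass, field
from typing import Any, Dict

import torch
import torch.nn as nn

@dataclass
class SnapshotPayload:
    model: nn.Module
    device: str  # device the model was quantized/exported for
    torch_version: str = str(torch.__version__)
    quantize_recipe: Dict[str, Any] = field(default_factory=dict)

def save_snapshot(model: nn.Module, device: str, recipe: Dict[str, Any], path: str) -> None:
    # Bundle the export conditions with the pickled module.
    torch.save(SnapshotPayload(model, device, str(torch.__version__), recipe), path)

def load_snapshot(path: str, requested_device: str) -> nn.Module:
    payload = torch.load(path, weights_only=False)
    if payload.device != requested_device:
        # Quantized kernels and weight layouts may not exist (or match) on the
        # requested device, so fail loudly rather than at first inference.
        raise RuntimeError(
            f"snapshot was exported for {payload.device}, requested {requested_device}"
        )
    return payload.model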

config = model.config
model = None
try:
model = torch.load(builder_args.snapshot_path, weights_only=False)
except Exception:
raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
# _active_backend() does not allow DSO & AOTI to be true.
# Choose either.
from torchchat.utils.build_utils import set_backend
set_backend(dso=True, pte=False, aoti_package=False)
if model.config != config:
raise RuntimeError("loaded model architecture mismatch")
##
## import all libraries with custom kernels and custom operators
## that quantize may be pulling in
##

elif builder_args.distributed:
pp_degree = builder_args.pp
tp_degree = builder_args.tp
14 changes: 13 additions & 1 deletion torchchat/cli/cli.py
@@ -207,6 +207,12 @@ def _add_export_output_path_args(parser) -> None:
default=None,
help="Output to the specified AOT Inductor .dso model file",
)
exclusive_parser.add_argument(
"--output-snapshot-path",
type=str,
default=None,
help="Output to the specified PyTorch model and sha256 file",
)
exclusive_parser.add_argument(
"--output-aoti-package-path",
type=str,
@@ -254,7 +260,13 @@ def _add_exported_input_path_args(parser) -> None:
default=None,
help="Use the specified ExecuTorch .pte model file",
)

exclusive_parser.add_argument(
"--snapshot-path",
type=Path,
default=None,
help="Use the specified torchchat snaphot .tc model file",
)


# Add CLI Args related to JIT downloading of model artifacts
def _add_jit_downloading_args(parser) -> None:
49 changes: 48 additions & 1 deletion torchchat/export.py
@@ -28,6 +28,31 @@
default_device = "cpu"


"""
Export Snapshot
"""


def export_snapshot(
model: nn.Module,
device: Optional[str] = None,
Contributor:
unused arg?

Contributor Author:
It's a cut-and-paste thing; I wanted to keep the args consistent. Mind you, we could put the device in the file or some such, and check on reload that it's the same. (I think MPS/CPU are sort of fungible, which might help on Macs with quantization when you run out of kernel memory to quantize large models; CPU could use paging to SSD for that. See e.g. the discussion on #1483.)

output_path: str = "model-snapshot.tc",
) -> str:
"""
Export the model as a snapshot.

Args:
model: The model to be exported.
device: The device to run the model on.
output_path: The path to save the exported model.
Returns:
The path to the exported model.
"""
assert output_path.endswith(".tc"), "use .tc extension for snapshots"
torch.save(model, output_path)
Contributor:
Do we need the whole model or can we get away with just the state_dict? https://github.com/pytorch/torchchat/pull/1280/files

That said, if we go with the slimmer state_dict, that's dependent on migrating to the AO quant that supports this kind of saving.

Contributor Author:
You rewrite the code as part of the quantization. If you don't save the code, then you must exactly replicate the quantization options you used, create an empty model, quantize it, and then load the state dict over it. Either you transfer that whole responsibility to the user (good luck; you'll die the death of a thousand cuts when users make mistakes and complain that this facility does not work), or you need to save an ungodly amount of information about the options used for the original quantization process.

Contributor:
We can update the serialization logic once the AO migration finishes (i.e., this PR is fine as-is), but I'm not sure that's still the case with AO: I was under the impression that the model itself is unaffected and only the weights are changed.

https://github.com/pytorch/ao/blob/48fdd310b3977a0db2ceba37a7725192cd2aafd4/docs/source/serialization.rst#L62

cc: @vmpuri @HDCharles

Contributor (@HDCharles), Jan 31, 2025:
You should be able to get by with just the state dict; there are a few APIs that don't work that way, but all the subclass ones do. That's like 75% of the reason we went with subclasses instead of module swaps.

Jack, your link is a good resource; an alternate reference is our serialization tests, to see what is explicitly tested:

see e.g. https://github.com/pytorch/ao/blob/48fdd310b3977a0db2ceba37a7725192cd2aafd4/test/integration/test_integration.py#L1322-L1334

https://github.com/pytorch/ao/blob/48fdd310b3977a0db2ceba37a7725192cd2aafd4/test/dtypes/test_affine_quantized.py#L101-L111
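
To make the trade-off in this thread concrete, here is a small editorial sketch (not part of this PR) contrasting the whole-module snapshot this PR takes with a state_dict-only snapshot. build_model and quantize_model are toy, hypothetical stand-ins for torchchat's real model construction and quantization step.

# Editorial sketch, not part of this diff. build_model and quantize_model are
# hypothetical stand-ins for torchchat's real model construction/quantization.
import torch
import torch.nn as nn

def build_model() -> nn.Module:
    # Toy architecture standing in for the real one.
    return nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

def quantize_model(model: nn.Module) -> nn.Module:
    # Placeholder for the real quantization step (e.g. tensor-subclass quant);
    # a no-op here so the sketch stays runnable.
    return model

# (a) Whole-module snapshot (what export_snapshot above does): the pickle
# carries the quantized module itself, so reload needs no knowledge of the
# quantization recipe, at the cost of weights_only=False unpickling.
model = quantize_model(build_model())
torch.save(model, "model-snapshot.tc")
restored = torch.load("model-snapshot.tc", weights_only=False)

# (b) state_dict-only snapshot: a slimmer artifact, but the loader must rebuild
# the architecture and re-apply the *same* quantization recipe before loading,
# because subclass-quantized weights only load into a matching module.
torch.save(model.state_dict(), "model-weights.tc")
fresh = quantize_model(build_model())
fresh.load_state_dict(torch.load("model-weights.tc", weights_only=False))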

return output_path


"""
Export for Server
"""
@@ -72,6 +97,7 @@ def export_for_server(
"aot_inductor.package": package,
"aot_inductor.metadata": metadata or {},
}

if not package:
options = {"aot_inductor.output_path": output_path}

@@ -373,14 +399,15 @@ def main(args):

output_pte_path = args.output_pte_path
output_dso_path = args.output_dso_path
output_snapshot_path = args.output_snapshot_path
output_aoti_package_path = args.output_aoti_package_path

if output_pte_path and builder_args.device != "cpu":
print(
f"Warning! ExecuTorch export target is controlled by export recipe, not device setting. Ignoring device={builder_args.device} setting."
)
builder_args.device = "cpu"
elif "mps" in builder_args.device:
elif (output_pte_path or output_dso_path or output_aoti_package_path) and "mps" in builder_args.device:
print("Warning! Device MPS not supported for export. Exporting for device CPU.")
builder_args.device = "cpu"

@@ -417,6 +444,7 @@ def main(args):
model_to_pte = model
model_to_dso = model
model_to_aoti_package = model
model_to_snapshot = model
else:
if output_pte_path:
_set_gguf_kwargs(builder_args, is_et=True, context="export")
@@ -436,6 +464,15 @@
model_to_dso = model_to_aoti_package
_unset_gguf_kwargs(builder_args)

if output_snapshot_path:
_set_gguf_kwargs(builder_args, is_et=False, context="export")
model_to_snapshot = _initialize_model(
builder_args,
quantize,
support_tensor_subclass=False,
)
_unset_gguf_kwargs(builder_args)

with torch.no_grad():
if output_pte_path:
output_pte_path = str(os.path.abspath(output_pte_path))
@@ -483,3 +520,13 @@ def main(args):
package=True,
metadata=metadata,
)

if output_snapshot_path:
output_snapshot_path = str(os.path.abspath(output_snapshot_path))
print(f"Exporting model using Snapshot to {output_snapshot_path}")
export_snapshot(
model_to_snapshot,
builder_args.device,
output_snapshot_path,
)
