Add export --output-snapshot-path snap.tc, and --snapshot-path snap.tc (#1465)
Changes from all commits
@@ -28,6 +28,31 @@
default_device = "cpu"


"""
Export Snapshot
"""


def export_snapshot(
    model: nn.Module,
    device: Optional[str] = None,
Review comment: unused arg?

Reply: It's a cut-and-paste thing, and I wanted to keep the args consistent. Mind you, we could record the device in the file or some such, and check on reload that it's the same. (I think MPS/CPU are sorta fungible, which might help on Macs with quantization when you run out of kernel memory to quantize large models; CPU could use paging to SSD for that. See e.g. the discussion on #1483.) A sketch of recording and checking such metadata follows the last review thread at the bottom of this page.
    output_path: str = "model-snapshot.tc",
) -> str:
    """
    Export the model as snapshot.

    Args:
        model: The model to be exported.
        device: The device to run the model on.
        output_path: The path to save the exported model.
    Returns:
        The path to the exported model.
    """
    assert output_path.endswith(".tc"), "use .tc extension for snapshots"
    torch.save(model, output_path)
Review comment: Do we need the whole model, or can we get away with just the state_dict? https://github.com/pytorch/torchchat/pull/1280/files That said, if we go with the slimmer state_dict, that's dependent on the migration to the AO quant API that supports this kind of saving. (A state_dict-based sketch follows this diff hunk.)

Reply: You rewrite the code as part of the quantization. If you don't save the code, then you must exactly replicate the quantization options you used, create an empty model, quantize it, and then load the state dict over it. Either you transfer that whole responsibility to the user (good luck; you'll die the death of many cuts when users make mistakes and complain that this facility does not work), or you need to save an ungodly amount of information about the options used for the original quantization process.

Reply: We can update the serialization logic as the AO migration finishes (i.e., this PR is good as-is), but I'm not sure that's still the case with AO. I was under the impression that the model itself is unaffected and that only the weights are changed. cc: @vmpuri @HDCharles

Reply: You should be able to get by with just the state dict. There are a few APIs that don't work that way, but all the subclass ones do; that's like 75% of the reason we went with subclasses instead of module swaps. Jack, your link is a good resource; an alternate reference is our serialization tests, which show what is explicitly tested.
    return output_path


"""
Export for Server
"""
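On the state_dict question in the review thread above: with torchao-style tensor-subclass quantization, the quantized representation is carried by the tensors in the state_dict itself, so a slimmer variant would save only the state_dict and load it into a freshly constructed, un-quantized model with assign=True. This is a hedged sketch of that alternative, not what the PR implements; build_model is a placeholder for whatever constructs the torchchat model skeleton.

```python
import torch
import torch.nn as nn


def save_state_dict_snapshot(model: nn.Module, output_path: str) -> str:
    # Only parameters/buffers are written; with tensor subclasses these already
    # encode the quantized layout, so no module code needs to be pickled.
    torch.save(model.state_dict(), output_path)
    return output_path


def load_state_dict_snapshot(build_model, snapshot_path: str) -> nn.Module:
    # build_model() stands in for rebuilding the (un-quantized) architecture,
    # e.g. on the meta device to avoid allocating full-precision weights.
    model = build_model()
    state_dict = torch.load(snapshot_path, weights_only=False)
    # assign=True swaps the loaded (possibly subclass) tensors in directly
    # instead of copying their values into the existing parameters.
    model.load_state_dict(state_dict, assign=True)
    return model
```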
@@ -72,6 +97,7 @@ def export_for_server(
        "aot_inductor.package": package,
        "aot_inductor.metadata": metadata or {},
    }

    if not package:
        options = {"aot_inductor.output_path": output_path}
@@ -373,14 +399,15 @@ def main(args):

    output_pte_path = args.output_pte_path
    output_dso_path = args.output_dso_path
    output_snapshot_path = args.output_snapshot_path
    output_aoti_package_path = args.output_aoti_package_path

    if output_pte_path and builder_args.device != "cpu":
        print(
            f"Warning! ExecuTorch export target is controlled by export recipe, not device setting. Ignoring device={builder_args.device} setting."
        )
        builder_args.device = "cpu"
-    elif "mps" in builder_args.device:
+    elif (output_pte_path or output_dso_path or output_aoti_package_path) and "mps" in builder_args.device:
(Jack-Khuu marked this conversation as resolved.)
| print("Warning! Device MPS not supported for export. Exporting for device CPU.") | ||
| builder_args.device = "cpu" | ||
|
|
||
|
|
@@ -417,6 +444,7 @@ def main(args):
        model_to_pte = model
        model_to_dso = model
        model_to_aoti_package = model
        model_to_snapshot = model
    else:
        if output_pte_path:
            _set_gguf_kwargs(builder_args, is_et=True, context="export")
@@ -436,6 +464,15 @@ def main(args):
            model_to_dso = model_to_aoti_package
            _unset_gguf_kwargs(builder_args)

        if output_snapshot_path:
            _set_gguf_kwargs(builder_args, is_et=False, context="export")
            model_to_snapshot = _initialize_model(
                builder_args,
                quantize,
                support_tensor_subclass=False,
            )
            _unset_gguf_kwargs(builder_args)

    with torch.no_grad():
        if output_pte_path:
            output_pte_path = str(os.path.abspath(output_pte_path))
@@ -483,3 +520,13 @@ def main(args):
            package=True,
            metadata=metadata,
        )

    if output_snapshot_path:
        output_snapshot_path = str(os.path.abspath(output_snapshot_path))
        print(f"Exporting model using Snapshot to {output_snapshot_path}")
        export_snapshot(
            model_to_snapshot,
            builder_args.device,
            output_snapshot_path,
        )
Review comment: Q: Does the saved artifact still work if the device has changed? I recall this being an issue with AO (one of the reasons why we didn't add saving earlier).

Reply: Most likely it does not. It really depends on which quantizations are performed and whether they're implemented on the different platforms in the same way. That is, if it's the same PyTorch/Python code doing the computation with quantized numbers, and the same quantization formats are supported, then yes.

If it's a C/C++/CUDA operator, it needs to be supported on the target, with the same name, or with a suitable if/then/else (i.e., don't bake the "device" setting in).

Quantization weight-format layouts need to be consistent, or the loader needs to repack them at load time. (That is totally plausible to do, but I don't think we do it today; in the 4b case "we just know". I tried to change that, but the need/priority wasn't similarly perceived by everybody.)

Once the model is saved and reloaded, most (all?) decisions you made are set in stone, such as quantization schemes (otherwise you'd be loading from scratch). In that sense it's similar to how the dso/aoti/pte output-path and load-path options work, and that's why this is modeled after that export-and-reload facility. You don't get to change the PTE target on reload, or the device an AOTI model has been compiled for.

Reply: I think expecting the exporting conditions to be the same as the executing conditions is a fair start.
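Tying the two device discussions together ("unused arg?" above and this thread): one way to make "exporting conditions must match executing conditions" explicit would be to bundle a small metadata dict with the module and verify it at load time. This is purely a sketch of that suggestion; none of these keys, helpers, or checks exist in the PR.

```python
import torch
import torch.nn as nn


def export_snapshot_with_metadata(
    model: nn.Module,
    device: str,
    quantize_options: dict,
    output_path: str = "model-snapshot.tc",
) -> str:
    # Hypothetical variant of export_snapshot: record the export-time
    # conditions next to the module so the loader can detect mismatches.
    torch.save(
        {
            "model": model,
            "export_device": device,
            "quantize_options": quantize_options,
            "torch_version": torch.__version__,
        },
        output_path,
    )
    return output_path


def load_snapshot_checked(snapshot_path: str, device: str) -> nn.Module:
    payload = torch.load(snapshot_path, map_location=device, weights_only=False)
    if payload["export_device"] != device:
        # Alternatively, attempt a repack/transfer here instead of failing hard.
        raise RuntimeError(
            f"Snapshot was exported for {payload['export_device']}, "
            f"but this run requested {device}."
        )
    return payload["model"]
```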