Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit ea41439

Browse files
committed
[Not for land] Util for saving quantized model
1 parent 9fb7999 commit ea41439

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

torchchat.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
"where": "Return directory containing downloaded model artifacts",
5050
"server": "[WIP] Starts a locally hosted REST server for model interaction",
5151
"eval": "Evaluate a model via lm-eval",
52+
"save_quant": "Quantize a model and save it to disk",
5253
}
5354
for verb, description in VERB_HELP.items():
5455
subparser = subparsers.add_parser(verb, help=description)
@@ -115,5 +116,9 @@
115116
from torchchat.cli.download import remove_main
116117

117118
remove_main(args)
119+
elif args.command == "save_quant":
120+
from torchchat.save_quant import main as save_quant_main
121+
122+
save_quant_main(args)
118123
else:
119124
parser.print_help()

torchchat/save_quant.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import os
8+
from pathlib import Path
9+
from typing import Optional
10+
11+
import torch
12+
import torch.nn as nn
13+
14+
from torchchat.cli.builder import (
15+
_initialize_model,
16+
BuilderArgs,
17+
)
18+
19+
from torchchat.utils.build_utils import set_precision
20+
21+
from torchao.quantization import quantize_, int8_weight_only
22+
23+
"""
24+
Exporting Flow
25+
"""
26+
27+
28+
def main(args):
    """Quantize a model and save its state dict to disk.

    The model is built via ``BuilderArgs.from_args``. If no quantization
    config was given on the CLI (``args.quantize`` is falsy), the model is
    quantized here with torchao's ``quantize_`` using int8 weight-only
    quantization; otherwise ``_initialize_model`` already applied the CLI
    config. The resulting ``state_dict`` is saved as ``model.pth`` in a
    sibling directory of the original checkpoint named
    ``<checkpoint_dir>-<quant_format>``.

    Args:
        args: Parsed CLI namespace (must carry the fields consumed by
            ``BuilderArgs.from_args`` plus ``quantize``).
    """
    builder_args = BuilderArgs.from_args(args)
    print(f"{builder_args=}")

    # Suffix used to name the output directory.
    # NOTE(review): hard-coded even when a CLI --quantize config is supplied,
    # so the directory name may not reflect the actual scheme — confirm.
    quant_format = "int8_wo"

    # Quant option from cli, can be None.
    model = _initialize_model(builder_args, args.quantize)
    if not args.quantize:
        # No quantization option from the CLI; quantize here with torchao.
        print("Quantizing model using torchao quantize_")
        quantize_(model, int8_weight_only())
    else:
        print(f"{args.quantize=}")

    print(f"Model: {model}")

    # Save next to the original checkpoint, in "<checkpoint_dir>-<quant_format>".
    checkpoint_dir = os.path.dirname(builder_args.checkpoint_path)
    model_dir = Path(checkpoint_dir + "-" + quant_format)
    # Idempotent directory creation; also creates missing parents
    # (replaces the try/os.mkdir/except FileExistsError pattern).
    model_dir.mkdir(parents=True, exist_ok=True)

    dest = model_dir / "model.pth"
    state_dict = model.state_dict()
    print(f"{state_dict.keys()=}")

    print(f"Saving checkpoint to {dest}. This may take a while.")
    torch.save(state_dict, dest)
    print("Done.")

0 commit comments

Comments
 (0)