Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 35db1f9

Browse files
committed
Use package_aoti API
1 parent f275b04 commit 35db1f9

File tree

3 files changed

+29
-28
lines changed

3 files changed

+29
-28
lines changed

.ci/scripts/validate.sh

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -133,60 +133,60 @@ function generate_aoti_model_output() {
133133
echo "******************************************"
134134
echo "************** non-quantized *************"
135135
echo "******************************************"
136-
python3 -W ignore torchchat.py export --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path "${MODEL_DIR}/${MODEL_NAME}.so" --device "$TARGET_DEVICE" || exit 1
137-
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path "$MODEL_DIR/${MODEL_NAME}.so" --prompt "$PROMPT" --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
136+
python3 -W ignore torchchat.py export --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path "${MODEL_DIR}/${MODEL_NAME}.pt2" --device "$TARGET_DEVICE" || exit 1
137+
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path "$MODEL_DIR/${MODEL_NAME}.pt2" --prompt "$PROMPT" --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
138138
.ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
139139

140140
echo "******************************************"
141141
echo "******* Emb: channel-wise quantized ******"
142142
echo "******************************************"
143-
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
144-
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
143+
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1
144+
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
145145
.ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
146146

147147
echo "******************************************"
148148
echo "******** Emb: group-wise quantized *******"
149149
echo "******************************************"
150-
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
151-
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
150+
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1
151+
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
152152
.ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
153153

154154
echo "***********************************************"
155155
echo "******* Emb: 4bit channel-wise quantized ******"
156156
echo "***********************************************"
157-
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 0, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
158-
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
157+
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 0, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1
158+
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
159159
.ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
160160

161161
echo "***********************************************"
162162
echo "******** Emb: 4bit group-wise quantized *******"
163163
echo "***********************************************"
164-
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 8, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
165-
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
164+
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 8, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1
165+
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
166166
.ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
167167

168168
if [ "${EXCLUDE_INT8_QUANT:-false}" == false ]; then
169169
echo "******************************************"
170170
echo "******* INT8 channel-wise quantized ******"
171171
echo "******************************************"
172-
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
173-
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
172+
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1
173+
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
174174
.ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
175175

176176
echo "******************************************"
177177
echo "******** INT8 group-wise quantized *******"
178178
echo "******************************************"
179-
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
180-
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
179+
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1
180+
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
181181
.ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
182182
fi
183183
echo "******************************************"
184184
echo "******** INT4 group-wise quantized *******"
185185
echo "******************************************"
186186
if [[ "$TARGET_DEVICE" != "cuda" || "$DTYPE" == "bfloat16" ]]; then
187187
# For CUDA, only bfloat16 makes sense for int4 mm kernel
188-
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
189-
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
188+
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1
189+
python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
190190
.ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
191191
fi
192192
done
@@ -285,8 +285,8 @@ function eval_model_sanity_check() {
285285
echo "******** INT4 group-wise quantized (AOTI) *******"
286286
echo "*************************************************"
287287
if [ "$DTYPE" != "float16" ]; then
288-
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --dynamic-shapes --device "$TARGET_DEVICE" || exit 1
289-
python3 -W ignore torchchat.py eval --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1
288+
python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --dynamic-shapes --device "$TARGET_DEVICE" || exit 1
289+
python3 -W ignore torchchat.py eval --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1
290290
cat "$MODEL_DIR/output_eval_aoti"
291291
fi;
292292
fi;

.github/workflows/pull.yml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -378,8 +378,8 @@ jobs:
378378
379379
echo "::group::Run inference with quantize file"
380380
if [ $(uname -s) == Darwin ]; then
381-
python3 torchchat.py export --output-dso-path /tmp/model.so --quantize torchchat/quant_config/cuda.json --checkpoint "./checkpoints/${REPO_NAME}/model.pth"
382-
python3 torchchat.py generate --dso-path /tmp/model.so --checkpoint "./checkpoints/${REPO_NAME}/model.pth"~
381+
python3 torchchat.py export --output-dso-path /tmp/model.pt2 --quantize torchchat/quant_config/cuda.json --checkpoint "./checkpoints/${REPO_NAME}/model.pth"
382+
python3 torchchat.py generate --dso-path /tmp/model.pt2 --checkpoint "./checkpoints/${REPO_NAME}/model.pth"~
383383
fi
384384
echo "::endgroup::"
385385
@@ -1016,8 +1016,8 @@ jobs:
10161016
10171017
for dtype in fp32 fp16 bf16 fast fast16; do
10181018
echo "Running export + runner with dtype=$dtype"
1019-
python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --dtype $dtype --output-dso-path /tmp/model.so
1020-
./cmake-out/aoti_run /tmp/model.so -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}"
1019+
python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --dtype $dtype --output-dso-path /tmp/model.pt2
1020+
./cmake-out/aoti_run /tmp/model.pt2 -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}"
10211021
done
10221022
10231023
echo "Tests complete."
@@ -1111,8 +1111,8 @@ jobs:
11111111
python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
11121112
./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}"
11131113
echo "Export and run AOTI (C++ runner)"
1114-
python torchchat.py export stories110M --output-dso-path ./model.so --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
1115-
./cmake-out/aoti_run ./model.so -z ./tokenizer.model -t 0 -i "${PRMT}"
1114+
python torchchat.py export stories110M --output-dso-path ./model.pt2 --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
1115+
./cmake-out/aoti_run ./model.pt2 -z ./tokenizer.model -t 0 -i "${PRMT}"
11161116
echo "Generate AOTI"
1117-
python torchchat.py generate stories110M --dso-path ./model.so --prompt "${PRMT}"
1117+
python torchchat.py generate stories110M --dso-path ./model.pt2 --prompt "${PRMT}"
11181118
echo "Tests complete."

torchchat/export.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@
3535
def export_for_server(
3636
model: nn.Module,
3737
device: Optional[str] = "cpu",
38-
output_path: str = "model.dso",
38+
output_path: str = "model.pt2",
3939
dynamic_shapes: bool = False,
4040
) -> str:
4141
"""
42-
Export the model using AOT Compile to get a .dso for server use cases.
42+
Export the model using AOT Compile to get a .pt2 for server use cases.
4343
4444
Args:
4545
model: The model to be exported.
@@ -71,7 +71,8 @@ def export_for_server(
7171
options={"aot_inductor.output_path": output_path},
7272
dynamic_shapes=dynamic_shapes,
7373
)
74-
print(f"The generated DSO model can be found at: {so}")
74+
package_aoti(output_path, so)
75+
print(f"The generated DSO model can be found at: {output_path}")
7576
return so
7677

7778

0 commit comments

Comments
 (0)