
Commit abf0679
Merge branch 'main' into refactor/dist_run
2 parents 80f8138 + 7fe2c86

13 files changed (+201 lines, -151 lines)

.ci/scripts/run-docs

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ fi
 
 if [ "$1" == "readme" ]; then
   echo "::group::Create script to run README"
-  python3 torchchat/utils/scripts/updown.py --create-sections --file README.md --replace 'llama3:stories15M,-l 3:-l 2' --suppress huggingface-cli,HF_TOKEN > ./run-readme.sh
+  python3 torchchat/utils/scripts/updown.py --create-sections --file README.md --replace 'llama3.1:stories15M,-l 3:-l 2' --suppress huggingface-cli,HF_TOKEN > ./run-readme.sh
   # for good measure, if something happened to updown processor,
   # and it did not error out, fail with an exit 1
   echo "exit 1" >> ./run-readme.sh

.github/workflows/pull.yml

Lines changed: 6 additions & 27 deletions
@@ -1092,32 +1092,11 @@ jobs:
         id: install-torchao-ops
         run: |
           bash torchchat/utils/scripts/build_torchao_ops.sh
-      - name: Set git shas
-        id: setup-hash
-        run: |
-          export TORCHCHAT_ROOT=${PWD}
-          echo "et-git-hash=$(cat ${TORCHCHAT_ROOT}/install/.pins/et-pin.txt)" >> "$GITHUB_ENV"
-      - name: Load or install ET
-        id: install-et
-        uses: actions/cache@v4
-        with:
-          path: |
-            ./et-build
-            ./torchchat/utils/scripts/install_et.sh
-          key: et-build-${{runner.os}}-${{runner.arch}}-${{env.et-git-hash}}-${{ hashFiles('**/install_et.sh') }}
-      - if: ${{ steps.install-et.outputs.cache-hit != 'true' }}
-        continue-on-error: true
+      - name: Install ET
         run: |
           echo "Installing ExecuTorch"
+          export TORCHCHAT_ROOT=${PWD}
           bash torchchat/utils/scripts/install_et.sh
-      - name: Install ExecuTorch python
-        run: |
-          echo "Install ExecuTorch python"
-          export TORCHCHAT_ROOT=$PWD
-          export ET_BUILD_DIR="et-build"
-          ENABLE_ET_PYBIND="${1:-true}"
-          source "torchchat/utils/scripts/install_utils.sh"
-          install_executorch_python_libs $ENABLE_ET_PYBIND
       - name: Install runner
         run: |
           echo "Installing runner"
@@ -1132,14 +1111,14 @@ jobs:
           wget -O ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
           export PRMT="Once upon a time in a land far away"
           echo "Generate eager"
-          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
           echo "Generate compile"
-          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile
+          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile
           echo "Export and run ET (C++ runner)"
-          python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+          python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
           ./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}"
           echo "Export and run AOTI (C++ runner)"
-          python torchchat.py export stories110M --output-dso-path ./model.so --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+          python torchchat.py export stories110M --output-dso-path ./model.so --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
           ./cmake-out/aoti_run ./model.so -z ./tokenizer.model -t 0 -i "${PRMT}"
           echo "Generate AOTI"
           python torchchat.py generate stories110M --dso-path ./model.so --prompt "${PRMT}"

README.md

Lines changed: 8 additions & 7 deletions
@@ -171,7 +171,7 @@ python3 torchchat.py download llama3.1
 <summary>Additional Model Inventory Management Commands</summary>
 
 ### Where
-This subcommand shows location of a particular model.
+This subcommand shows the location of a particular model.
 ```bash
 python3 torchchat.py where llama3.1
 ```
@@ -216,7 +216,6 @@ This mode generates text based on an input prompt.
 python3 torchchat.py generate llama3.1 --prompt "write me a story about a boy and his bear"
 ```
 
-[skip default]: end
 
 ### Server
 This mode exposes a REST API for interacting with a model.
@@ -286,14 +285,16 @@ First, follow the steps in the Server section above to start a local server. The
 streamlit run torchchat/usages/browser.py
 ```
 
+[skip default]: end
+
 Use the "Max Response Tokens" slider to limit the maximum number of tokens generated by the model for each response. Click the "Reset Chat" button to remove the message history and start a fresh chat.
 
 
 ## Desktop/Server Execution
 
 ### AOTI (AOT Inductor)
 [AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution for faster inference. The process creates a [DSO](https://en.wikipedia.org/wiki/Shared_library) model (represented by a file with extension `.so`)
-that is then loaded for inference. This can be done with both Python and C++ enviroments.
+that is then loaded for inference. This can be done with both Python and C++ environments.
 
 The following example exports and executes the Llama3.1 8B Instruct
 model. The first command compiles and performs the actual export.
@@ -308,9 +309,9 @@ python3 torchchat.py export llama3.1 --output-dso-path exportedModels/llama3.1.s
 For more details on quantization and what settings to use for your use
 case visit our [customization guide](docs/model_customization.md).
 
-### Run in a Python Enviroment
+### Run in a Python Environment
 
-To run in a python enviroment, use the generate subcommand like before, but include the dso file.
+To run in a python environment, use the generate subcommand like before, but include the dso file.
 
 ```
 python3 torchchat.py generate llama3.1 --dso-path exportedModels/llama3.1.so --prompt "Hello my name is"
@@ -377,7 +378,7 @@ While ExecuTorch does not focus on desktop inference, it is capable
 of doing so. This is handy for testing out PTE
 models without sending them to a physical device.
 
-Specifically there are 2 ways of doing so: Pure Python and via a Runner
+Specifically, there are 2 ways of doing so: Pure Python and via a Runner
 
 <details>
 <summary>Deploying via Python</summary>
@@ -501,7 +502,7 @@ The following assumes you've completed the steps for [Setting up ExecuTorch](#se
 and use [this script](https://github.com/pytorch/executorch/blob/main/build/build_android_llm_demo.sh) to build the AAR library.
 
 <p align="center">
-<img src="https://pytorch.org/executorch/main/_static/img/android_llama_app.png" width="600" alt="Android app running a LlaMA model">
+<img src="https://pytorch.org/executorch/main/_static/img/chat.png" width="600" alt="Android app running a LlaMA model">
 </p>
 

assets/view.jpg

93.3 KB
(binary image file; preview not shown)

docs/quantization.md

Lines changed: 15 additions & 8 deletions
@@ -121,22 +121,29 @@ python3 torchchat.py generate llama3 --pte-path llama3.pte --prompt "Hello my n
 ## Experimental TorchAO lowbit kernels
 
 ### Use
-The quantization scheme a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
+
+#### linear:a8wxdq
+The quantization scheme linear:a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
 It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7), groupsize, and has_weight_zeros (true, false).
 The argument has_weight_zeros indicates whether the weights are quantized with scales only (has_weight_zeros: false) or with both scales and zeros (has_weight_zeros: true).
 Roughly speaking, {bitwidth: 4, groupsize: 32, has_weight_zeros: false} is similar to GGML's Q4_0 quantization scheme.
 
-You should expect high performance on ARM CPU if bitwidth is 1, 2, 3, 4, or 5 and groupsize is divisible by 16. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
+You should expect high performance on ARM CPU if groupsize is divisible by 16. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
+
+#### embedding:wx
+The quantization scheme embedding:wx quantizes embeddings in a groupwise manner with the specified bitwidth and groupsize. It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7) and groupsize. Unlike linear:a8wxdq, embedding:wx always quantizes with scales and zeros.
+
+You should expect high performance on ARM CPU if groupsize is divisible by 32. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
 
 ### Setup
-To use a8wxdq, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
+To use linear:a8wxdq and embedding:wx, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
 
 From the torchchat root directory, run
 ```
 sh torchchat/utils/scripts/build_torchao_ops.sh
 ```
 
-This should take about 10 seconds to complete. Once finished, you can use a8wxdq in torchchat.
+This should take about 10 seconds to complete.
 
 Note: if you want to use the new kernels in the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts the build the runners.
 
@@ -156,17 +163,17 @@ Below we show how to use the new kernels. Except for ExecuTorch, you can specif
 
 #### Eager mode
 ```
-OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5
 ```
 
 #### torch.compile
 ```
-OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5
 ```
 
 #### AOTI
 ```
-OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-dso llama3_1.so
+OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-dso llama3_1.so
 OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --dso-path llama3_1.so --prompt "Once upon a time," --num-samples 5
 ```
 
@@ -178,7 +185,7 @@ OMP_NUM_THREADS=6 ./cmake-out/aoti_run llama3_1.so -z $HOME/.torchchat/model-cac
 
 #### ExecuTorch
 ```
-python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-pte llama3_1.pte
+python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-pte llama3_1.pte
 ```
 
 Note: only the ExecuTorch C++ runner in torchchat when built using the instructions in the setup can run the exported *.pte file. It will not work with the `python torchchat.py generate` command.
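As a minimal sketch of how the scheme arguments described above fit together, the same `--quantize` JSON used in these commands can be assembled and passed from Python; this assumes torchchat.py sits in the working directory and the torchao kernels were built per the Setup step.

```python
# Minimal sketch: build the --quantize config documented above and hand it to
# torchchat via its CLI. Assumes the torchao experimental kernels are built.
import json
import subprocess

quantize_config = {
    "embedding:wx": {"bitwidth": 2, "groupsize": 32},
    "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": False},
}

subprocess.run(
    [
        "python3", "torchchat.py", "generate", "llama3.1",
        "--device", "cpu",
        "--dtype", "float32",
        "--quantize", json.dumps(quantize_config),
        "--prompt", "Once upon a time,",
        "--num-samples", "5",
    ],
    check=True,  # raise if torchchat exits non-zero
)
```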

install/.pins/torchao-pin.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-49b1fb61c8b8eceda755579a2fd92c756d822de2
+c8f1174a06dcc0102849c8348ca6573bde8847a9

torchchat/cli/builder.py

Lines changed: 12 additions & 15 deletions
@@ -84,19 +84,16 @@ def __post_init__(self):
         if self.dso_path and self.pte_path:
             raise RuntimeError("specify either DSO path or PTE path, but not both")
 
-        if self.checkpoint_path and (self.dso_path or self.pte_path):
-            print(
-                "Warning: checkpoint path ignored because an exported DSO or PTE path specified"
-            )
-        if self.checkpoint_dir and (self.dso_path or self.pte_path):
-            print(
-                "Warning: checkpoint dir ignored because an exported DSO or PTE path specified"
-            )
-        if self.gguf_path and (self.dso_path or self.pte_path):
-            print(
-                "Warning: GGUF path ignored because an exported DSO or PTE path specified"
-            )
-        if not (self.dso_path) and not (self.pte_path):
+        if self.dso_path or self.pte_path:
+            ignored_params = [
+                (self.checkpoint_path, "checkpoint path"),
+                (self.checkpoint_dir, "checkpoint dir"),
+                (self.gguf_path, "GGUF path"),
+            ]
+            for param, param_msg in ignored_params:
+                if param:
+                    print(f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified")
+        else:
             self.prefill_possible = True
 
     @classmethod
@@ -458,7 +455,7 @@ def _maybe_init_distributed(
     return world_mesh, parallel_dims
 
 
-def _maybe_parellelize_model(
+def _maybe_parallelize_model(
     model: nn.Module,
     builder_args: BuilderArgs,
     world_mesh: DeviceMesh,
@@ -498,7 +495,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # model = _init_model_on_meta_device(builder_args)
     else:
         model = _load_model_default(builder_args)
-        # model = _maybe_parellelize_model(model, builder_args, world_mesh, parallel_dims)
+        model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
 
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()

torchchat/cli/download.py

Lines changed: 11 additions & 4 deletions
@@ -10,7 +10,10 @@
 from pathlib import Path
 from typing import Optional
 
-from torchchat.cli.convert_hf_checkpoint import convert_hf_checkpoint, convert_hf_checkpoint_to_tune
+from torchchat.cli.convert_hf_checkpoint import (
+    convert_hf_checkpoint,
+    convert_hf_checkpoint_to_tune,
+)
 from torchchat.model_config.model_config import (
     load_model_configs,
     ModelConfig,
@@ -57,7 +60,6 @@ def _download_hf_snapshot(
         snapshot_download(
             model_config.distribution_path,
             local_dir=artifact_dir,
-            local_dir_use_symlinks=False,
             token=hf_token,
             ignore_patterns=ignore_patterns,
         )
@@ -77,9 +79,14 @@ def _download_hf_snapshot(
         raise e
 
     # Convert the Multimodal Llama model to the torchtune format.
-    if model_config.name in {"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision"}:
+    if model_config.name in {
+        "meta-llama/Llama-3.2-11B-Vision-Instruct",
+        "meta-llama/Llama-3.2-11B-Vision",
+    }:
         print(f"Converting {model_config.name} to torchtune format...", file=sys.stderr)
-        convert_hf_checkpoint_to_tune( model_dir=artifact_dir, model_name=model_config.name)
+        convert_hf_checkpoint_to_tune(
+            model_dir=artifact_dir, model_name=model_config.name
+        )
 
     else:
         # Convert the model to the torchchat format.

torchchat/generate.py

Lines changed: 9 additions & 2 deletions
@@ -180,8 +180,15 @@ def from_args(cls, args):
 
         # Validate that all image prompts exist before expensive model load
         if image_prompts := getattr(args, "image_prompts", None):
-            if not all(os.path.exists(image_prompt) for image_prompt in image_prompts):
-                raise RuntimeError(f"Image prompt {image_prompt} does not exist")
+            non_existent_image_prompts = [
+                image_prompt
+                for image_prompt in image_prompts
+                if (not os.path.exists(image_prompt))
+            ]
+            if len(non_existent_image_prompts):
+                raise RuntimeError(
+                    f"Image prompt {non_existent_image_prompts} does not exist"
+                )
 
         return cls(
             prompt=getattr(args, "prompt", ""),

torchchat/utils/build_utils.py

Lines changed: 20 additions & 25 deletions
@@ -6,6 +6,7 @@
 
 from __future__ import annotations
 
+from enum import Enum
 import logging
 import os
 from pathlib import Path
@@ -78,36 +79,33 @@ def set_backend(dso, pte):
     active_builder_args_pte = pte
 
 
-def use_aoti_backend() -> bool:
+class _Backend(Enum):
+    AOTI = 0,
+    EXECUTORCH = 1
+
+
+def _active_backend() -> _Backend:
     global active_builder_args_dso
     global active_builder_args_pte
 
     # eager == aoti, which is when backend has not been explicitly set
    if (not active_builder_args_dso) and not (active_builder_args_pte):
-        return True
+        return _Backend.AOTI
 
    if active_builder_args_pte and active_builder_args_dso:
        raise RuntimeError(
            "code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
        )
 
-    return bool(active_builder_args_dso)
+    return _Backend.AOTI if active_builder_args_dso else _Backend.EXECUTORCH
 
 
-def use_et_backend() -> bool:
-    global active_builder_args_dso
-    global active_builder_args_pte
-
-    # eager == aoti, which is when backend has not been explicitly set
-    if not (active_builder_args_pte or active_builder_args_dso):
-        return False
+def use_aoti_backend() -> bool:
+    return _active_backend() == _Backend.AOTI
 
-    if active_builder_args_pte and active_builder_args_dso:
-        raise RuntimeError(
-            "code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
-        )
 
-    return bool(active_builder_args_pte)
+def use_et_backend() -> bool:
+    return _active_backend() == _Backend.EXECUTORCH
 
 
 ##########################################################################
@@ -142,9 +140,9 @@ def name_to_dtype(name, device):
             return torch.float16
         return torch.bfloat16
 
-    if name in name_to_dtype_dict:
+    try:
         return name_to_dtype_dict[name]
-    else:
+    except KeyError:
         raise RuntimeError(f"unsupported dtype name {name} specified")
 
 
@@ -212,10 +210,7 @@ def canonical_path(path):
 
 
 def state_dict_device(d, device="cpu") -> Dict:
-    for key, weight in d.items():
-        d[key] = weight.to(device=device)
-
-    return d
+    return {key : weight.to(device=device) for (key, weight) in d.items()}
 
 
 #########################################################################
@@ -259,9 +254,9 @@ def get_device(device) -> str:
     return torch.device(device)
 
 
-def is_cuda_or_cpu_device(device) -> bool:
-    return device == "" or str(device) == "cpu" or ("cuda" in str(device))
-
-
 def is_cpu_device(device) -> bool:
     return device == "" or str(device) == "cpu"
+
+
+def is_cuda_or_cpu_device(device) -> bool:
+    return is_cpu_device(device) or ("cuda" in str(device))
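The refactor above routes both backend queries through a single `_active_backend()` helper backed by a private enum, so the default-to-AOTI behavior and the mutual-exclusion check live in one place. A minimal usage sketch, assuming `set_backend` keeps the `(dso, pte)` signature shown in the hunk header and that the module-level flags start out unset:

```python
# Minimal sketch exercising the refactored backend helpers from build_utils.py.
# Assumes set_backend(dso, pte) and the query functions shown in the diff above.
from torchchat.utils.build_utils import set_backend, use_aoti_backend, use_et_backend

# With neither export target set, eager execution is treated as the AOTI backend.
assert use_aoti_backend() and not use_et_backend()

# Exporting to a DSO keeps the AOTI backend active...
set_backend(dso=True, pte=False)
assert use_aoti_backend()

# ...while exporting to a PTE file switches to the ExecuTorch backend.
set_backend(dso=False, pte=True)
assert use_et_backend()
```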
