
Commit 9a53478 (parent: 9c47edc)

[aoti] Add cpp packaging for aoti + loading in python

File tree: 8 files changed, +140 −38 lines


.gitignore

Lines changed: 4 additions & 0 deletions

```diff
@@ -6,6 +6,7 @@ __pycache__/
 # C extensions
 *.so
 
+.vscode
 .model-artifacts/
 .venv
 .torchchat
@@ -24,3 +25,6 @@ system_info.txt
 
 # intermediate system file
 .DS_Store
+checkpoints/
+exportedModels/
+cmake-out/
```

README.md

Lines changed: 11 additions & 5 deletions

````diff
@@ -260,8 +260,9 @@ that is then loaded for inference. This can be done with both Python and C++ env
 
 The following example exports and executes the Llama3.1 8B Instruct
 model. The first command compiles and performs the actual export.
-```
-python3 torchchat.py export llama3.1 --output-dso-path exportedModels/llama3.1.so
+
+```bash
+python3 torchchat.py export llama3.1 --output-aoti-package-path exportedModels/llama3_1_artifacts
 ```
 
 > [!NOTE]
@@ -275,7 +276,7 @@ case visit our [customization guide](docs/model_customization.md).
 
 To run in a Python environment, use the generate subcommand like before, but include the dso file.
 
-```
+```bash
 python3 torchchat.py generate llama3.1 --dso-path exportedModels/llama3.1.so --prompt "Hello my name is"
 ```
 **Note:** Depending on which accelerator is used to generate the .dso file, the command may need the device specified: `--device (cuda | cpu)`.
@@ -288,9 +289,14 @@ To run in a C++ environment, we need to build the runner binary.
 torchchat/utils/scripts/build_native.sh aoti
 ```
 
-Then run the compiled executable, with the exported DSO from earlier.
+To compile the AOTI generated artifacts into a `.so`:
+```bash
+make -C exportedModels/llama3_1_artifacts
+```
+
+Then run the compiled executable, with the compiled DSO.
 ```bash
-cmake-out/aoti_run exportedModels/llama3.1.so -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
+cmake-out/aoti_run exportedModels/llama3_1_artifacts/llama3_1_artifacts.so -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
 ```
 **Note:** Depending on which accelerator is used to generate the .dso file, the runner may need the device specified: `-d (CUDA | CPU)`.
 
````
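To make the new Python-side flow concrete, here is a minimal sketch of loading the exported package directly, relying on the same private `torch._inductor.package.load_package` API this commit wires into `torchchat/cli/builder.py` below. The path, input shapes, and dtypes are illustrative assumptions, not part of the commit.

```python
# Hedged sketch: load the AOTI package exported above and call its forward.
# Assumes a PyTorch build that provides torch._inductor.package.load_package
# (the same private API used in torchchat/cli/builder.py below).
import torch
from torch._inductor.package import load_package

compiled_forward = load_package(
    "exportedModels/llama3_1_artifacts",  # value of --output-aoti-package-path
    "cpu",                                # device the artifacts were compiled for
)

# Illustrative inputs only: a batch of token ids and their positions, which
# must match the example inputs the model was exported with.
tokens = torch.tensor([[1]], dtype=torch.int)
input_pos = torch.tensor([0], dtype=torch.int)
logits = compiled_forward(tokens, input_pos)
```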

torchchat/cli/builder.py

Lines changed: 50 additions & 12 deletions

```diff
@@ -53,6 +53,7 @@ class BuilderArgs:
     gguf_path: Optional[Union[Path, str]] = None
     gguf_kwargs: Optional[Dict[str, Any]] = None
     dso_path: Optional[Union[Path, str]] = None
+    aoti_package_path: Optional[Union[Path, str]] = None
     pte_path: Optional[Union[Path, str]] = None
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
@@ -72,28 +73,29 @@ def __post_init__(self):
             or (self.checkpoint_dir and self.checkpoint_dir.is_dir())
             or (self.gguf_path and self.gguf_path.is_file())
             or (self.dso_path and Path(self.dso_path).is_file())
+            or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
             or (self.pte_path and Path(self.pte_path).is_file())
         ):
             raise RuntimeError(
                 "need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
             )
 
-        if self.dso_path and self.pte_path:
-            raise RuntimeError("specify either DSO path or PTE path, but not both")
+        if self.pte_path and self.aoti_package_path:
+            raise RuntimeError("specify either AOTI Package path or PTE path, but not more than one")
 
-        if self.checkpoint_path and (self.dso_path or self.pte_path):
+        if self.checkpoint_path and (self.pte_path or self.aoti_package_path):
             print(
-                "Warning: checkpoint path ignored because an exported DSO or PTE path specified"
+                "Warning: checkpoint path ignored because an exported AOTI or PTE path specified"
             )
-        if self.checkpoint_dir and (self.dso_path or self.pte_path):
+        if self.checkpoint_dir and (self.pte_path or self.aoti_package_path):
             print(
-                "Warning: checkpoint dir ignored because an exported DSO or PTE path specified"
+                "Warning: checkpoint dir ignored because an exported AOTI or PTE path specified"
             )
-        if self.gguf_path and (self.dso_path or self.pte_path):
+        if self.gguf_path and (self.pte_path or self.aoti_package_path):
             print(
-                "Warning: GGUF path ignored because an exported DSO or PTE path specified"
+                "Warning: GGUF path ignored because an exported AOTI or PTE path specified"
             )
-        if not (self.dso_path) and not (self.pte_path):
+        if not (self.dso_path) and not (self.aoti_package_path):
             self.prefill_possible = True
 
     @classmethod
@@ -123,6 +125,7 @@ def from_args(cls, args): # -> BuilderArgs:
 
         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
+        aoti_package_path = getattr(args, "aoti_package_path", None)
 
         is_chat_model = False
         if args.is_chat_model:
@@ -133,6 +136,7 @@ def from_args(cls, args): # -> BuilderArgs:
             checkpoint_dir,
             dso_path,
             pte_path,
+            aoti_package_path,
             args.gguf_path,
         ]:
             if path is not None:
@@ -148,6 +152,7 @@ def from_args(cls, args): # -> BuilderArgs:
 
 
         output_pte_path = getattr(args, "output_pte_path", None)
+        output_aoti_package_path = getattr(args, "output_aoti_package_path", None)
         output_dso_path = getattr(args, "output_dso_path", None)
         if output_pte_path and args.dtype.startswith("fast"):
             if args.dtype == "fast":
@@ -169,10 +174,11 @@ def from_args(cls, args): # -> BuilderArgs:
             gguf_path=args.gguf_path,
             gguf_kwargs=None,
             dso_path=dso_path,
+            aoti_package_path=aoti_package_path,
             pte_path=pte_path,
             device=args.device,
             precision=dtype,
-            setup_caches=(output_dso_path or output_pte_path),
+            setup_caches=(output_dso_path or output_pte_path or output_aoti_package_path),
             use_distributed=args.distributed,
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
@@ -187,6 +193,7 @@ def from_speculative_args(cls, args): # -> BuilderArgs:
         speculative_builder_args.checkpoint_path = args.draft_checkpoint_path
         speculative_builder_args.gguf_path = None
         speculative_builder_args.dso_path = None
+        speculative_builder_args.aoti_package_path = None
         speculative_builder_args.pte_path = None
         return speculative_builder_args
 
@@ -466,11 +473,12 @@ def _initialize_model(
 ):
     print("Loading model...")
 
-    if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
+    if builder_args.gguf_path and (builder_args.dso_path or builder_args.aoti_package_path or builder_args.pte_path):
         print("Setting gguf_kwargs for generate.")
         is_dso = builder_args.dso_path is not None
+        is_aoti_package = builder_args.aoti_package_path is not None
         is_pte = builder_args.pte_path is not None
-        assert not (is_dso and is_pte)
+        assert not (is_dso and is_aoti_package and is_pte)
         assert builder_args.gguf_kwargs is None
         # TODO: make GGUF load independent of backend
         # currently not working because AVX int_mm broken
@@ -504,6 +512,36 @@ def _initialize_model(
             )
         except:
             raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
+
+    elif builder_args.aoti_package_path:
+        if not is_cuda_or_cpu_device(builder_args.device):
+            print(
+                f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
+            )
+            builder_args.device = "cpu"
+
+        # assert (
+        #     quantize is None or quantize == "{ }"
+        # ), "quantize not valid for exported PT2 model. Specify quantization during export."
+
+        with measure_time("Time to load model: {time:.02f} seconds"):
+            model = _load_model(builder_args, only_config=True)
+            device_sync(device=builder_args.device)
+
+        try:
+            # Replace model forward with the AOT-compiled forward.
+            # This is a hacky way to quickly demo AOTI's capability.
+            # model is still a Python object, and any mutation to its
+            # attributes will NOT be seen by the AOTI-compiled forward
+            # function, e.g. calling model.setup_cache will NOT touch
+            # AOTI compiled and maintained model buffers such as kv_cache.
+            from torch._inductor.package import load_package
+
+            model.forward = load_package(
+                str(builder_args.aoti_package_path.absolute()), builder_args.device
+            )
+        except:
+            raise RuntimeError(f"Failed to load AOTI compiled {builder_args.aoti_package_path}")
+
     elif builder_args.pte_path:
         if not is_cpu_device(builder_args.device):
             print(
```
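The precedence rules in `__post_init__` above are easy to check in isolation. The toy below (hypothetical names, not the real `BuilderArgs`) reproduces the two behaviors the diff adds: mutually exclusive exported-artifact paths, and source inputs being ignored when an exported artifact is given.

```python
# Self-contained toy reproducing the validation logic added above.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyBuilderArgs:
    checkpoint_path: Optional[str] = None
    pte_path: Optional[str] = None
    aoti_package_path: Optional[str] = None

    def __post_init__(self):
        # At most one exported-artifact path may drive the build.
        if self.pte_path and self.aoti_package_path:
            raise RuntimeError("specify either AOTI Package path or PTE path, but not more than one")
        # Source inputs lose to an exported artifact.
        if self.checkpoint_path and (self.pte_path or self.aoti_package_path):
            print("Warning: checkpoint path ignored because an exported AOTI or PTE path specified")

ToyBuilderArgs(aoti_package_path="exportedModels/llama3_1_artifacts")  # accepted
ToyBuilderArgs(checkpoint_path="model.pth", pte_path="model.pte")      # warns
# ToyBuilderArgs(pte_path="model.pte", aoti_package_path="pkg")        # raises
```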

torchchat/cli/cli.py

Lines changed: 12 additions & 0 deletions

```diff
@@ -191,6 +191,12 @@ def _add_export_output_path_args(parser) -> None:
         default=None,
         help="Output to the specified AOT Inductor .dso model file",
     )
+    output_path_parser.add_argument(
+        "--output-aoti-package-path",
+        type=str,
+        default=None,
+        help="Output directory for AOTInductor compiled artifacts",
+    )
 
 
 def _add_export_args(parser) -> None:
@@ -220,6 +226,12 @@ def _add_exported_input_path_args(parser) -> None:
         default=None,
         help="Use the specified AOT Inductor .dso model file",
     )
+    exclusive_parser.add_argument(
+        "--aoti-package-path",
+        type=Path,
+        default=None,
+        help="Use the specified directory containing AOT Inductor compiled files",
+    )
     exclusive_parser.add_argument(
         "--pte-path",
         type=Path,
```
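`exclusive_parser` is, as its name suggests, presumably an argparse mutually exclusive group, so `--dso-path`, `--aoti-package-path`, and `--pte-path` cannot be combined on one command line. A minimal sketch of that pattern (the flag names come from the commit; everything else is illustrative):

```python
# Minimal argparse sketch of the mutually exclusive input-path flags.
import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
exclusive = parser.add_mutually_exclusive_group()
exclusive.add_argument("--dso-path", type=Path, default=None)
exclusive.add_argument("--aoti-package-path", type=Path, default=None)
exclusive.add_argument("--pte-path", type=Path, default=None)

args = parser.parse_args(["--aoti-package-path", "exportedModels/llama3_1_artifacts"])
print(args.aoti_package_path)

# Passing two of the flags at once makes argparse exit with an error such as:
# "argument --pte-path: not allowed with argument --aoti-package-path"
```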

torchchat/export.py

Lines changed: 31 additions & 8 deletions

```diff
@@ -35,8 +35,10 @@
 def export_for_server(
     model: nn.Module,
     device: Optional[str] = "cpu",
-    output_path: str = "model.dso",
+    output_path: str = "model.pt2",
     dynamic_shapes: bool = False,
+    package: bool = True,
+    model_key: str = "",
 ) -> str:
     """
     Export the model using AOT Compile to get a .dso for server use cases.
@@ -65,14 +67,17 @@ def export_for_server(
         dynamic_shapes = None
 
     with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
-        so = torch._export.aot_compile(
+        path = torch._export.aot_compile(
             model,
             args=input,
-            options={"aot_inductor.output_path": output_path},
+            options={
+                "aot_inductor.output_path": output_path,
+                "aot_inductor.package": package,
+            },
             dynamic_shapes=dynamic_shapes,
         )
-    print(f"The generated DSO model can be found at: {so}")
-    return so
+    print(f"The generated DSO model can be found at: {path}")
+    return path
 
 
 """
@@ -335,14 +340,16 @@ def main(args):
 
     print(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
-    set_backend(dso=args.output_dso_path, pte=args.output_pte_path)
+    set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path)
 
     builder_args.dso_path = None
     builder_args.pte_path = None
+    builder_args.aoti_package_path = None
     builder_args.setup_caches = True
 
     output_pte_path = args.output_pte_path
     output_dso_path = args.output_dso_path
+    output_aoti_package_path = args.output_aoti_package_path
 
     if output_pte_path and builder_args.device != "cpu":
         print(
@@ -380,6 +387,7 @@ def main(args):
         )
         model_to_pte = model
         model_to_dso = model
+        model_to_aoti_package = model
     else:
         if output_pte_path:
             _set_gguf_kwargs(builder_args, is_et=True, context="export")
@@ -389,13 +397,14 @@ def main(args):
             )
             _unset_gguf_kwargs(builder_args)
 
-        if output_dso_path:
+        if output_dso_path or output_aoti_package_path:
             _set_gguf_kwargs(builder_args, is_et=False, context="export")
-            model_to_dso = _initialize_model(
+            model_to_aoti_package = _initialize_model(
                 builder_args,
                 quantize,
                 support_tensor_subclass=False,
             )
+            model_to_dso = model_to_aoti_package
             _unset_gguf_kwargs(builder_args)
 
     with torch.no_grad():
@@ -409,6 +418,7 @@ def main(args):
                 "Export with executorch requested but ExecuTorch could not be loaded"
             )
             print(executorch_exception)
+
         if output_dso_path:
             output_dso_path = str(os.path.abspath(output_dso_path))
             print(f"Exporting model using AOT Inductor to {output_dso_path}")
@@ -417,4 +427,17 @@ def main(args):
                 builder_args.device,
                 output_dso_path,
                 builder_args.dynamic_shapes,
+                package=False,
+            )
+
+        if output_aoti_package_path:
+            output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
+            print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
+            export_for_server(
+                model_to_aoti_package,
+                builder_args.device,
+                output_aoti_package_path,
+                builder_args.dynamic_shapes,
+                package=True,
+                model_key=builder_args.params_table,
             )
```
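For readers unfamiliar with the packaged export path, here is a minimal sketch of the `aot_compile` call pattern above, applied to a toy module. It assumes the same private `torch._export.aot_compile` API and `aot_inductor.*` option keys the commit uses; the module, shapes, and output directory are illustrative.

```python
# Hedged sketch of a packaged AOTI export (toy module, illustrative paths).
import torch
import torch._export
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1.0

example_inputs = (torch.randn(4, 8),)
artifact_path = torch._export.aot_compile(
    Toy(),
    args=example_inputs,
    options={
        # With package=True the output is a directory of compilable
        # artifacts (plus a Makefile, per the README above) rather than
        # a single prebuilt .so.
        "aot_inductor.output_path": "exportedModels/toy_artifacts",
        "aot_inductor.package": True,
    },
)
print(f"AOTI artifacts written to: {artifact_path}")
```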

torchchat/generate.py

Lines changed: 16 additions & 3 deletions

```diff
@@ -133,8 +133,8 @@ def validate_build(
             reason = "model compilation for prefill"
         if self.compile:
             reason = "model compilation"
-        if builder_args.dso_path:
-            model_type = "DSO"
+        if builder_args.aoti_package_path:
+            model_type = "PT2"
         if builder_args.pte_path:
             model_type = "PTE"
         if model_type and reason:
@@ -146,7 +146,10 @@
     def from_args(cls, args):
         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
-        sequential_prefill = args.sequential_prefill or bool(dso_path) or bool(pte_path)
+        aoti_package_path = getattr(args, "aoti_package_path", None)
+        sequential_prefill = (
+            args.sequential_prefill or bool(aoti_package_path) or bool(pte_path)
+        )
 
         return cls(
             prompt=getattr(args, "prompt", ""),
@@ -948,3 +951,13 @@ def main(args):
         torch.cuda.reset_peak_memory_stats()
     for _ in gen.chat(generator_args):
         pass
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="torchchat generate CLI")
+    verb = "generate"
+    add_arguments_for_verb(parser, verb)
+    args = parser.parse_args()
+    check_args(args, verb)
+    args = arg_init(args)
+    main(args)
```
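The new `__main__` block makes `torchchat/generate.py` invocable as a standalone script. The equivalent programmatic call looks roughly like the sketch below; the import paths and CLI arguments are assumptions based on the helpers named in the diff, not something this commit spells out.

```python
# Hedged sketch of driving generate.main() directly (assumed import paths).
import argparse

from torchchat.cli.cli import add_arguments_for_verb, arg_init, check_args
from torchchat.generate import main

parser = argparse.ArgumentParser(description="torchchat generate CLI")
add_arguments_for_verb(parser, "generate")
args = parser.parse_args(
    ["llama3.1",  # model positional, as in the README examples above
     "--aoti-package-path", "exportedModels/llama3_1_artifacts",
     "--prompt", "Hello my name is"]
)
check_args(args, "generate")
args = arg_init(args)
main(args)
```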

torchchat/usages/eval.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -260,7 +260,7 @@ def main(args) -> None:
 
     if compile:
         assert not (
-            builder_args.dso_path or builder_args.pte_path
+            builder_args.dso_path or builder_args.pte_path or builder_args.aoti_package_path
        ), "cannot compile exported model"
         model_forward = torch.compile(
             model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True
@@ -288,6 +288,8 @@ def main(args) -> None:
         )
     if builder_args.dso_path:
         print(f"For model {builder_args.dso_path}")
+    if builder_args.aoti_package_path:
+        print(f"For model {builder_args.aoti_package_path}")
     elif builder_args.pte_path:
         print(f"For model {builder_args.pte_path}")
     elif builder_args.checkpoint_path:
```
