
Commit 3b52e0e

Merge branch 'main' into main

2 parents f442e9e + 4a7dab8

File tree: 3 files changed (+40, −12 lines)
torchchat/cli/builder.py

Lines changed: 20 additions & 0 deletions
@@ -536,6 +536,15 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         model = _load_model_default(builder_args)
     # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)

+    if builder_args.dso_path or builder_args.aoti_package_path:
+        # The AOTI-compiled model will load its own weights.
+        # Release the eager weights here to avoid OOM.
+        import gc
+        if hasattr(model, "model"):
+            model.model = None
+        gc.collect()
+        torch.cuda.empty_cache()
+
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
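
The pattern in this hunk stands on its own: drop the eager-mode weights before the AOTI artifact loads its own copy, so the two never coexist in GPU memory. A minimal self-contained sketch of the same idea (the helper name and wrapper attribute are illustrative, not part of the commit):

import gc

import torch


def release_eager_weights(wrapper) -> None:
    # Drop the Python-side reference to the eager weights; the AOTI-compiled
    # artifact will load its own copy, so keeping both risks OOM.
    if hasattr(wrapper, "model"):
        wrapper.model = None
    # Force a collection so the tensors are actually freed, then return the
    # now-unused cached blocks to the CUDA allocator.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()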

@@ -584,6 +593,12 @@ def _initialize_model(
             # attributes will NOT be seen by the AOTI-compiled forward
             # function, e.g. calling model.setup_cache will NOT touch
             # AOTI compiled and maintained model buffers such as kv_cache.
+            # Using the cpp runner to run the AOTI-compiled model is recommended.
+
+            def do_nothing(max_batch_size, max_seq_length):
+                pass
+            model.setup_caches = do_nothing
+
             model.forward = torch._export.aot_load(
                 str(builder_args.dso_path.absolute()), builder_args.device
             )
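
Both the DSO branch above and the .pt2 package branch below apply the same stub, so the idea reads more clearly as one helper. A minimal sketch (the function name is hypothetical; torch._export.aot_load is the call used in the hunk):

import torch
import torch._export


def attach_aoti_forward(model, dso_path: str, device: str):
    # The AOTI artifact owns buffers such as kv_cache, so cache setup on the
    # Python-side wrapper must become a no-op rather than silently diverge.
    def do_nothing(max_batch_size, max_seq_length):
        pass

    model.setup_caches = do_nothing
    # Swap the eager forward for the compiled callable loaded from the .so.
    model.forward = torch._export.aot_load(dso_path, device)
    return model
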
@@ -617,6 +632,11 @@ def _initialize_model(
             aoti_compiled_model = load_package(
                 str(builder_args.aoti_package_path.absolute())
             )
+
+            def do_nothing(max_batch_size, max_seq_length):
+                pass
+            model.setup_caches = do_nothing
+
             model.forward = aoti_compiled_model
             metadata = aoti_compiled_model.get_metadata()
             builder_args.device = metadata["AOTI_DEVICE_KEY"]

torchchat/export.py

Lines changed: 3 additions & 3 deletions
@@ -78,11 +78,11 @@ def export_for_server(
             dynamic_shapes=dynamic_shapes,
             options=options,
         )
-
+
         if package:
             from torch._inductor.package import package_aoti
             path = package_aoti(output_path, path)
-
+
     print(f"The generated packaged model can be found at: {path}")
     return path

(The two paired blank-line changes in this hunk are whitespace-only.)
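
The path handed to package_aoti here is the set of files produced by torch._export.aot_compile when packaging is enabled, and the resulting .pt2 archive is what builder.py's load_package consumes above. A minimal round-trip sketch under that assumption (the toy module and file name are placeholders, not torchchat code):

import torch
import torch._export
from torch._inductor.package import load_package, package_aoti


class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)


# Compile with packaging enabled, as export_for_server does when package=True.
aoti_files = torch._export.aot_compile(
    Tiny(),
    args=(torch.randn(4, 8),),
    options={"aot_inductor.package": True},
)
pt2_path = package_aoti("tiny.pt2", aoti_files)  # write the .pt2 archive
compiled = load_package(pt2_path)                # callable compiled model
print(compiled(torch.randn(4, 8)).shape)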

@@ -382,7 +382,7 @@ def main(args):

     if builder_args.max_seq_length is None:
         if (
-            output_dso_path is not None
+            (output_dso_path is not None or output_aoti_package_path is not None)
             and not builder_args.dynamic_shapes
         ):
             print("Setting max_seq_length to 300 for DSO export.")

torchchat/utils/build_utils.py

Lines changed: 17 additions & 9 deletions
@@ -11,7 +11,7 @@

 from enum import Enum
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch

@@ -77,31 +77,39 @@ def unpack_packed_weights(
 def set_backend(dso, pte, aoti_package):
     global active_builder_args_dso
     global active_builder_args_pte
+    global active_builder_args_aoti_package
     active_builder_args_dso = dso
     active_builder_args_aoti_package = aoti_package
     active_builder_args_pte = pte


 class _Backend(Enum):
-    AOTI = (0,)
+    AOTI = 0
     EXECUTORCH = 1


-def _active_backend() -> _Backend:
+def _active_backend() -> Optional[_Backend]:
     global active_builder_args_dso
     global active_builder_args_aoti_package
     global active_builder_args_pte

-    # eager == aoti, which is when backend has not been explicitly set
-    if (not active_builder_args_pte) and (not active_builder_args_aoti_package):
-        return True
+    args = (
+        active_builder_args_dso,
+        active_builder_args_pte,
+        active_builder_args_aoti_package,
+    )
+
+    # Return None as the default when no backend has been set.
+    if not any(args):
+        return None

-    if active_builder_args_pte and active_builder_args_aoti_package:
+    # Catch more than one export target being set at once.
+    if sum(map(bool, args)) > 1:
         raise RuntimeError(
-            "code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!"
+            "Code generation needs to choose different implementations. Please only use one export option, and call export twice if necessary!"
         )

-    return _Backend.AOTI if active_builder_args_pte else _Backend.EXECUTORCH
+    return _Backend.EXECUTORCH if active_builder_args_pte else _Backend.AOTI


 def use_aoti_backend() -> bool:
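
The consumers of _active_backend sit below this hunk and are not shown; the new contract itself can be read off the diff. A short usage sketch with placeholder artifact paths (eager, i.e. nothing set, now maps to None rather than the old bare return True):

from torchchat.utils.build_utils import _Backend, _active_backend, set_backend

# Nothing exported yet: the backend is simply undetermined.
set_backend(dso=None, pte=None, aoti_package=None)
assert _active_backend() is None

# Exactly one target set: a PTE file selects ExecuTorch...
set_backend(dso=None, pte="model.pte", aoti_package=None)
assert _active_backend() is _Backend.EXECUTORCH

# ...while a DSO (or an AOTI package) selects AOTI.
set_backend(dso="model.so", pte=None, aoti_package=None)
assert _active_backend() is _Backend.AOTI

# Two targets at once now trigger the reworded RuntimeError.
set_backend(dso="model.so", pte="model.pte", aoti_package=None)
try:
    _active_backend()
except RuntimeError as err:
    print(err)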
