Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 9671810

Browse files
committed
Refactor the code
1 parent e3acb5c commit 9671810

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

torchchat/cli/builder.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,15 @@ def _load_model(builder_args: BuilderArgs) -> Model:
510510
model = _load_model_default(builder_args)
511511
# model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
512512

513+
if builder_args.dso_path or builder_args.aoti_package_path:
514+
# AOTI-compoiled model will load its own weights.
515+
# Release weights here to avoid OOM
516+
import gc
517+
if hasattr(model, "model"):
518+
model.model = None
519+
gc.collect()
520+
torch.cuda.empty_cache()
521+
513522
model = model.to(device=builder_args.device, dtype=builder_args.precision)
514523
return model.eval()
515524

@@ -558,6 +567,12 @@ def _initialize_model(
558567
# attributes will NOT be seen on by AOTI-compiled forward
559568
# function, e.g. calling model.setup_cache will NOT touch
560569
# AOTI compiled and maintained model buffers such as kv_cache.
570+
# Using cpp runner to run AOTI compiled model is recommended.
571+
572+
def do_nothing(max_batch_size, max_seq_length):
573+
pass
574+
model.setup_caches = do_nothing
575+
561576
model.forward = torch._export.aot_load(
562577
str(builder_args.dso_path.absolute()), builder_args.device
563578
)

0 commit comments

Comments
 (0)