diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index a7a22a1e8..fb2bfb299 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -536,6 +536,15 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         model = _load_model_default(builder_args)
     # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
 
+    if builder_args.dso_path or builder_args.aoti_package_path:
+        # The AOTI-compiled model will load its own weights.
+        # Release the eager weights here to avoid OOM.
+        import gc
+        if hasattr(model, "model"):
+            model.model = None
+        gc.collect()
+        torch.cuda.empty_cache()
+
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
 
@@ -584,6 +593,12 @@ def _initialize_model(
             # attributes will NOT be seen on by AOTI-compiled forward
            # function, e.g. calling model.setup_cache will NOT touch
            # AOTI compiled and maintained model buffers such as kv_cache.
+            # Using the cpp runner to run an AOTI-compiled model is recommended.
+
+            def do_nothing(max_batch_size, max_seq_length):
+                pass
+            model.setup_caches = do_nothing
+
             model.forward = torch._export.aot_load(
                 str(builder_args.dso_path.absolute()), builder_args.device
             )
@@ -617,6 +632,11 @@ def _initialize_model(
             aoti_compiled_model = load_package(
                 str(builder_args.aoti_package_path.absolute())
             )
+
+            def do_nothing(max_batch_size, max_seq_length):
+                pass
+            model.setup_caches = do_nothing
+
             model.forward = aoti_compiled_model
             metadata = aoti_compiled_model.get_metadata()
             builder_args.device = metadata["AOTI_DEVICE_KEY"]
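
Note for reviewers: the `_load_model` change relies on dropping the last Python reference to the eager weights and then returning cached blocks to the driver. Below is a minimal, self-contained sketch (not part of the patch) of that release pattern; the `Wrapper` class and tensor sizes are hypothetical stand-ins for torchchat's `Model`, which holds the underlying transformer in a `model` attribute.

```python
import gc

import torch


class Wrapper(torch.nn.Module):
    # Hypothetical stand-in for torchchat's Model wrapper.
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(4096, 4096)


if torch.cuda.is_available():
    m = Wrapper().to("cuda")
    print(f"after load:    {torch.cuda.memory_allocated() / 2**20:.1f} MiB")

    # Same sequence as the patch: drop the last reference to the eager
    # weights, collect the now-unreachable tensors, then return cached
    # blocks to the driver so the AOTI artifact can load its own copy
    # without hitting OOM.
    m.model = None
    gc.collect()
    torch.cuda.empty_cache()
    print(f"after release: {torch.cuda.memory_allocated() / 2**20:.1f} MiB")
```

The `do_nothing` stub added in both AOTI branches follows from the existing comment: the compiled artifact maintains its own kv_cache buffers, so a later `model.setup_caches(...)` call from generation code would only allocate eager-mode caches that the compiled forward never reads. Replacing the method with a no-op avoids both the wasted memory and the false impression that the call took effect.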