@@ -536,6 +536,15 @@ def _load_model(builder_args: BuilderArgs) -> Model:
536536 model = _load_model_default (builder_args )
537537 # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
538538
539+ if builder_args .dso_path or builder_args .aoti_package_path :
540+ # AOTI-compiled model will load its own weights.
541+ # Release weights here to avoid OOM
542+ import gc
543+ if hasattr (model , "model" ):
544+ model .model = None
545+ gc .collect ()
546+ torch .cuda .empty_cache ()
547+
539548 model = model .to (device = builder_args .device , dtype = builder_args .precision )
540549 return model .eval ()
541550
@@ -584,6 +593,12 @@ def _initialize_model(
584593 # attributes will NOT be seen by AOTI-compiled forward
585594 # function, e.g. calling model.setup_cache will NOT touch
586595 # AOTI compiled and maintained model buffers such as kv_cache.
596+ # Using the cpp runner to run an AOTI-compiled model is recommended.
597+
598+ def do_nothing (max_batch_size , max_seq_length ):
599+ pass
600+ model .setup_caches = do_nothing
601+
587602 model .forward = torch ._export .aot_load (
588603 str (builder_args .dso_path .absolute ()), builder_args .device
589604 )
@@ -617,6 +632,11 @@ def _initialize_model(
617632 aoti_compiled_model = load_package (
618633 str (builder_args .aoti_package_path .absolute ())
619634 )
635+
636+ def do_nothing (max_batch_size , max_seq_length ):
637+ pass
638+ model .setup_caches = do_nothing
639+
620640 model .forward = aoti_compiled_model
621641 metadata = aoti_compiled_model .get_metadata ()
622642 builder_args .device = metadata ["AOTI_DEVICE_KEY" ]
0 commit comments