@@ -2059,6 +2059,7 @@ def test_torch_compile_recompilation_and_graph_break(self):
20592059 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
20602060
20612061 model = self .model_class (** init_dict ).to (torch_device )
2062+ model .eval ()
20622063 model = torch .compile (model , fullgraph = True )
20632064
20642065 with (
@@ -2076,6 +2077,7 @@ def test_torch_compile_repeated_blocks(self):
20762077 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
20772078
20782079 model = self .model_class (** init_dict ).to (torch_device )
2080+ model .eval ()
20792081 model .compile_repeated_blocks (fullgraph = True )
20802082
20812083 recompile_limit = 1
@@ -2098,7 +2100,6 @@ def test_compile_with_group_offloading(self):
20982100
20992101 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
21002102 model = self .model_class (** init_dict )
2101-
21022103 model .eval ()
21032104 # TODO: Can test for other group offloading kwargs later if needed.
21042105 group_offload_kwargs = {
@@ -2111,25 +2112,46 @@ def test_compile_with_group_offloading(self):
21112112 }
21122113 model .enable_group_offload (** group_offload_kwargs )
21132114 model .compile ()
2115+
21142116 with torch .no_grad ():
21152117 _ = model (** inputs_dict )
21162118 _ = model (** inputs_dict )
21172119
2118- @require_torch_version_greater ("2.7.1" )
21192120 def test_compile_on_different_shapes (self ):
21202121 if self .different_shapes_for_compilation is None :
21212122 pytest .skip (f"Skipping as `different_shapes_for_compilation` is not set for { self .__class__ .__name__ } ." )
21222123 torch .fx .experimental ._config .use_duck_shape = False
21232124
21242125 init_dict , _ = self .prepare_init_args_and_inputs_for_common ()
21252126 model = self .model_class (** init_dict ).to (torch_device )
2127+ model .eval ()
21262128 model = torch .compile (model , fullgraph = True , dynamic = True )
21272129
21282130 for height , width in self .different_shapes_for_compilation :
21292131 with torch ._dynamo .config .patch (error_on_recompile = True ), torch .no_grad ():
21302132 inputs_dict = self .prepare_dummy_input (height = height , width = width )
21312133 _ = model (** inputs_dict )
21322134
2135+ def test_compile_works_with_aot (self ):
2136+ from torch ._inductor .package import load_package
2137+
2138+ init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
2139+
2140+ model = self .model_class (** init_dict ).to (torch_device )
2141+ exported_model = torch .export .export (model , args = (), kwargs = inputs_dict )
2142+
2143+ with tempfile .TemporaryDirectory () as tmpdir :
2144+ package_path = os .path .join (tmpdir , f"{ self .model_class .__name__ } .pt2" )
2145+ _ = torch ._inductor .aoti_compile_and_package (exported_model , package_path = package_path )
2146+ assert os .path .exists (package_path )
2147+ loaded_binary = load_package (package_path , run_single_threaded = True )
2148+
2149+ model .forward = loaded_binary
2150+
2151+ with torch .no_grad ():
2152+ _ = model (** inputs_dict )
2153+ _ = model (** inputs_dict )
2154+
21332155
21342156@slow
21352157@require_torch_2
0 commit comments