# See the License for the specific language governing permissions and
# limitations under the License.
import gc
- import unittest
+ import inspect

import torch

@@ -23,7 +23,7 @@

@require_torch_gpu
@slow
- class QuantCompileTests(unittest.TestCase):
+ class QuantCompileTests:
    @property
    def quantization_config(self):
        raise NotImplementedError(
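Concrete suites supply a backend config by overriding this property. A minimal sketch of such an override, assuming diffusers' `PipelineQuantizationConfig` (the config class, backend string, and kwargs are assumptions, not shown in this diff):

```python
# Hypothetical override for a bitsandbytes 4-bit backend; the config class,
# backend string, and kwargs are assumptions, not part of this diff.
from diffusers.quantizers import PipelineQuantizationConfig

@property
def quantization_config(self):
    return PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={"load_in_4bit": True},
        components_to_quantize=["transformer"],
    )
```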
@@ -50,30 +50,26 @@ def _init_pipeline(self, quantization_config, torch_dtype):
        )
        return pipe

-     def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
-         pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
-         # import to ensure fullgraph True
+     def _test_torch_compile(self, torch_dtype=torch.bfloat16):
+         pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda")
+         # `fullgraph=True` ensures no graph breaks
        pipe.transformer.compile(fullgraph=True)

-         for _ in range(2):
-             # small resolutions to ensure speedy execution.
-             pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+         # small resolutions to ensure speedy execution.
+         pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

-     def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
-         pipe = self._init_pipeline(quantization_config, torch_dtype)
+     def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
+         pipe = self._init_pipeline(self.quantization_config, torch_dtype)
        pipe.enable_model_cpu_offload()
        pipe.transformer.compile()

-         for _ in range(2):
-             # small resolutions to ensure speedy execution.
-             pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+         # small resolutions to ensure speedy execution.
+         pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

-     def _test_torch_compile_with_group_offload_leaf(
-         self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
-     ):
-         torch._dynamo.config.cache_size_limit = 10000
+     def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16, *, use_stream: bool = False):
+         torch._dynamo.config.cache_size_limit = 1000

-         pipe = self._init_pipeline(quantization_config, torch_dtype)
+         pipe = self._init_pipeline(self.quantization_config, torch_dtype)
        group_offload_kwargs = {
            "onload_device": torch.device("cuda"),
            "offload_device": torch.device("cpu"),
@@ -87,6 +83,17 @@ def _test_torch_compile_with_group_offload_leaf(
            if torch.device(component.device).type == "cpu":
                component.to("cuda")

-         for _ in range(2):
-             # small resolutions to ensure speedy execution.
-             pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+         # small resolutions to ensure speedy execution.
+         pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+
+     def test_torch_compile(self):
+         self._test_torch_compile()
+
+     def test_torch_compile_with_cpu_offload(self):
+         self._test_torch_compile_with_cpu_offload()
+
+     def test_torch_compile_with_group_offload_leaf(self, use_stream=False):
+         # If a concrete subclass defines its own override of this test,
+         # bail out here so only the subclass variant runs.
+         for cls in inspect.getmro(self.__class__):
+             if "test_torch_compile_with_group_offload_leaf" in cls.__dict__ and cls is not QuantCompileTests:
+                 return
+         self._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
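Taken together, the diff turns `QuantCompileTests` into a plain mixin: concrete suites combine it with `unittest.TestCase`, supply `quantization_config`, and may shadow the group-offload test via the `inspect.getmro` check. A hypothetical wiring under those assumptions (class name, backend string, and kwargs are illustrative, not part of this diff):

```python
# Hypothetical concrete suite; the class name, backend string, and kwargs
# are illustrative assumptions, not part of this diff.
import unittest

from diffusers.quantizers import PipelineQuantizationConfig


class TorchAoCompileTests(QuantCompileTests, unittest.TestCase):
    @property
    def quantization_config(self):
        return PipelineQuantizationConfig(
            quant_backend="torchao",
            quant_kwargs={"quant_type": "int8wo"},
            components_to_quantize=["transformer"],
        )

    # Because this name appears in the subclass __dict__, the base class's
    # inspect.getmro() check returns early, so only this streamed variant runs.
    def test_torch_compile_with_group_offload_leaf(self):
        self._test_torch_compile_with_group_offload_leaf(use_stream=True)
```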