     nightly,
     require_torch,
     require_torch_gpu,
-    require_torchao_version_greater,
+    require_torchao_version_greater_or_equal,
     slow,
     torch_device,
 )
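
Note on the import rename above: the version guard now requires torchao >= 0.7.0 (see the decorator changes below). As a point of reference, a ">= version" skip decorator can be written with `packaging.version`; the sketch below is an illustrative assumption, not the actual helper in `diffusers.utils.testing_utils`.

```python
# Illustrative sketch of a ">= version" skip decorator for torchao-backed tests.
# Assumes torchao exposes __version__; this is NOT the actual diffusers implementation.
import unittest

from packaging import version


def require_torchao_version_greater_or_equal(torchao_version: str):
    def decorator(test_case):
        try:
            import torchao

            available = version.parse(torchao.__version__) >= version.parse(torchao_version)
        except ImportError:
            available = False
        return unittest.skipUnless(available, f"test requires torchao>={torchao_version}")(test_case)

    return decorator
```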
@@ -74,13 +74,13 @@ def forward(self, input, *args, **kwargs):
 
 if is_torchao_available():
     from torchao.dtypes import AffineQuantizedTensor
-    from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
     from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+    from torchao.utils import get_model_size_in_bytes
 
 
 @require_torch
 @require_torch_gpu
-@require_torchao_version_greater("0.6.0")
+@require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoConfigTest(unittest.TestCase):
     def test_to_dict(self):
         """
@@ -125,7 +125,7 @@ def test_repr(self):
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_gpu
-@require_torchao_version_greater("0.6.0")
+@require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoTest(unittest.TestCase):
     def tearDown(self):
         gc.collect()
@@ -139,11 +139,13 @@ def get_dummy_components(self, quantization_config: TorchAoConfig):
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )
-        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
-        text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+        text_encoder_2 = T5EncoderModel.from_pretrained(
+            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+        )
         tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
         tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
         scheduler = FlowMatchEulerDiscreteScheduler()
 
         return {
@@ -212,7 +214,7 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
     def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]):
         components = self.get_dummy_components(quantization_config)
         pipe = FluxPipeline(**components)
-        pipe.to(device=torch_device, dtype=torch.bfloat16)
+        pipe.to(device=torch_device)
 
         inputs = self.get_dummy_inputs(torch_device)
         output = pipe(**inputs)[0]
@@ -276,7 +278,6 @@ def test_int4wo_quant_bfloat16_conversion(self):
         self.assertTrue(isinstance(weight, AffineQuantizedTensor))
         self.assertEqual(weight.quant_min, 0)
         self.assertEqual(weight.quant_max, 15)
-        self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
 
     def test_device_map(self):
         """
@@ -341,21 +342,33 @@ def test_device_map(self):
 
     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
-        quantized_model = FluxTransformer2DModel.from_pretrained(
+        quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
             "hf-internal-testing/tiny-flux-pipe",
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )
 
-        unquantized_layer = quantized_model.transformer_blocks[0].ff.net[2]
+        unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
         self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
         self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
         self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)
 
-        quantized_layer = quantized_model.proj_out
+        quantized_layer = quantized_model_with_not_convert.proj_out
         self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
-        self.assertEqual(quantized_layer.weight.layout_tensor.data.dtype, torch.int8)
+
+        quantization_config = TorchAoConfig("int8_weight_only")
+        quantized_model = FluxTransformer2DModel.from_pretrained(
+            "hf-internal-testing/tiny-flux-pipe",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16,
+        )
+
+        size_quantized_with_not_convert = get_model_size_in_bytes(quantized_model_with_not_convert)
+        size_quantized = get_model_size_in_bytes(quantized_model)
+
+        self.assertTrue(size_quantized < size_quantized_with_not_convert)
 
     def test_training(self):
         quantization_config = TorchAoConfig("int8_weight_only")
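
The reworked assertions above check `modules_to_not_convert` by comparing model sizes rather than poking at layout internals: modules excluded from quantization keep bfloat16 weights, so the fully quantized model must be strictly smaller. A standalone sketch of that comparison, mirroring the test and assuming `get_model_size_in_bytes` is available from torchao >= 0.7.0:

```python
# Standalone sketch of the size-based check performed by test_modules_to_not_convert.
import torch
from torchao.utils import get_model_size_in_bytes

from diffusers import FluxTransformer2DModel, TorchAoConfig

model_id = "hf-internal-testing/tiny-flux-pipe"

# int8 weight-only quantization, but leave the first transformer block in bfloat16.
partially_quantized = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"]),
    torch_dtype=torch.bfloat16,
)

# int8 weight-only quantization applied everywhere.
fully_quantized = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8_weight_only"),
    torch_dtype=torch.bfloat16,
)

# Skipped modules keep 2-byte bfloat16 weights, so the fully quantized model is smaller.
assert get_model_size_in_bytes(fully_quantized) < get_model_size_in_bytes(partially_quantized)
```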
@@ -406,23 +419,6 @@ def test_torch_compile(self):
         # Note: Seems to require higher tolerance
         self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
 
-    @staticmethod
-    def _get_memory_footprint(module):
-        quantized_param_memory = 0.0
-        unquantized_param_memory = 0.0
-
-        for param in module.parameters():
-            if param.__class__.__name__ == "AffineQuantizedTensor":
-                data, scale, zero_point = param.layout_tensor.get_plain()
-                quantized_param_memory += data.numel() + data.element_size()
-                quantized_param_memory += scale.numel() + scale.element_size()
-                quantized_param_memory += zero_point.numel() + zero_point.element_size()
-            else:
-                unquantized_param_memory += param.data.numel() * param.data.element_size()
-
-        total_memory = quantized_param_memory + unquantized_param_memory
-        return total_memory, quantized_param_memory, unquantized_param_memory
-
     def test_memory_footprint(self):
         r"""
         A simple test to check if the model conversion has been done correctly by checking on the
@@ -433,20 +429,18 @@ def test_memory_footprint(self):
         transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
         transformer_bf16 = self.get_dummy_components(None)["transformer"]
 
-        total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo)
-        total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint(
-            transformer_int4wo_gs32
-        )
-        total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo)
-        total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16)
-
-        self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16)
-        # int4wo_gs32 has smaller group size, so more groups -> more scales and zero points
-        self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32)
-        # int4 with default group size quantized very few linear layers compared to a smaller group size of 32
-        self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32)
+        total_int4wo = get_model_size_in_bytes(transformer_int4wo)
+        total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
+        total_int8wo = get_model_size_in_bytes(transformer_int8wo)
+        total_bf16 = get_model_size_in_bytes(transformer_bf16)
+
+        # Latter has smaller group size, so more groups -> more scales and zero points
+        self.assertTrue(total_int4wo < total_int4wo_gs32)
         # int8 quantizes more layers compare to int4 with default group size
-        self.assertTrue(quantized_int8wo < quantized_int4wo)
+        self.assertTrue(total_int8wo < total_int4wo)
+        # int4wo does not quantize too many layers because of default group size, but for the layers it does
+        # there is additional overhead of scales and zero points
+        self.assertTrue(total_bf16 < total_int4wo)
 
     def test_wrong_config(self):
         with self.assertRaises(ValueError):
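
The rewritten size assertions encode two effects: int8 weight-only quantization touches more linear layers in this tiny model than int4 with the default group size, and a smaller int4 group size (32 vs. the default 128) means more scale/zero-point pairs, so `int4wo-32` ends up larger than `int4wo`. A back-of-envelope illustration of the group-size effect; the 2-byte quantization-parameter size is an assumption for illustration, not taken from torchao internals:

```python
# Editor's illustration of why smaller group sizes inflate int4 weight-only storage:
# each group of weights along a row carries its own scale and zero point.
def approx_int4wo_bytes(rows: int, cols: int, group_size: int, qparam_bytes: int = 2) -> int:
    packed_weights = rows * cols // 2                  # two int4 values packed per byte
    num_groups = rows * (cols // group_size)           # one (scale, zero_point) pair per group
    return packed_weights + num_groups * 2 * qparam_bytes


# Same layer, two group sizes: group_size=32 stores 4x as many quantization parameters.
print(approx_int4wo_bytes(3072, 3072, group_size=128))  # fewer groups -> smaller footprint
print(approx_int4wo_bytes(3072, 3072, group_size=32))   # more groups -> larger footprint
```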
@@ -456,7 +450,7 @@ def test_wrong_config(self):
 # This class is not to be run as a test by itself. See the tests that follow this class
 @require_torch
 @require_torch_gpu
-@require_torchao_version_greater("0.6.0")
+@require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"
     quant_method, quant_method_kwargs = None, None
@@ -565,7 +559,7 @@ class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_gpu
-@require_torchao_version_greater("0.6.0")
+@require_torchao_version_greater_or_equal("0.7.0")
 @slow
 @nightly
 class SlowTorchAoTests(unittest.TestCase):
@@ -581,11 +575,13 @@ def get_dummy_components(self, quantization_config: TorchAoConfig):
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )
-        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
-        text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+        text_encoder_2 = T5EncoderModel.from_pretrained(
+            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+        )
         tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
         tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
         scheduler = FlowMatchEulerDiscreteScheduler()
 
         return {
@@ -617,7 +613,7 @@ def get_dummy_inputs(self, device: torch.device, seed: int = 0):
 
     def _test_quant_type(self, quantization_config, expected_slice):
         components = self.get_dummy_components(quantization_config)
-        pipe = FluxPipeline(**components).to(dtype=torch.bfloat16)
+        pipe = FluxPipeline(**components)
         pipe.enable_model_cpu_offload()
 
         inputs = self.get_dummy_inputs(torch_device)
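
As in the fast tests, every component is now loaded in bfloat16 at `from_pretrained` time, so the pipeline-level dtype cast is dropped and only `enable_model_cpu_offload()` is called before inference. A condensed usage sketch of that pattern outside the test harness; the checkpoint id and prompt are placeholders, not what the tests use:

```python
# Hedged usage sketch: torchao-quantized transformer + bf16 pipeline with CPU offload.
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/FLUX.1-dev"  # placeholder checkpoint
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8wo"),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # moves submodules to GPU only while they run

image = pipe("a photo of a dog", num_inference_steps=4, guidance_scale=3.5).images[0]
```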