@@ -39,20 +39,16 @@ def _get_quant_info(model):
3939 for name , module in model .named_modules ():
4040 with align_module_device (module ):
4141 if is_module_quantized (module ):
42+ # Zero points are deliberately not collected here: for symmetric
43+ # models they are removed between compression and decompression.
44+
4245 if module .quantization_scheme .weights is not None :
43- quant_info_weights [name ] = (
44- module .weight_scale ,
45- module .weight_zero_point ,
46- module .weight ,
47- )
46+ quant_info_weights [name ] = (module .weight_scale , module .weight )
4847
4948 if module .quantization_scheme .input_activations is not None :
5049 is_dynamic = module .quantization_scheme .input_activations .dynamic
5150 if not is_dynamic :
52- quant_info_inputs [name ] = (
53- module .input_scale ,
54- module .input_zero_point ,
55- )
51+ quant_info_inputs [name ] = (module .input_scale ,)
5652
5753 return quant_info_weights , quant_info_inputs
5854
@@ -110,23 +106,19 @@ def test_quantization_reload(setup_model_and_config):
110106 # TODO: can remove `to` calls after
111107 # https://github.com/neuralmagic/compressed-tensors/pull/427
112108
113- for name , (o_scale , o_zp , o_weight ) in og_weights .items ():
114- n_scale , n_zp , n_weight = reloaded_weights [name ]
109+ for name , (o_scale , o_weight ) in og_weights .items ():
110+ n_scale , n_weight = reloaded_weights [name ]
115111 assert o_scale .dtype == n_scale .dtype == config ["weight_dtype" ]
116112 assert torch .equal (o_scale , n_scale .to (o_scale .device ))
117- assert o_zp .dtype == n_zp .dtype
118- assert torch .equal (o_zp , n_zp .to (o_zp .device ))
119113
120114 # we don't expect an exact match here because o_weight still has the
121115 # original weight and n_weight has been fake_quantized
122116 assert n_weight .dtype == o_weight .dtype == config ["weight_dtype" ]
123117
124- for name , (o_scale , o_zp ) in og_inputs .items ():
125- n_scale , n_zp = reloaded_inputs [name ]
118+ for name , (o_scale ,) in og_inputs .items ():
119+ ( n_scale ,) = reloaded_inputs [name ]
126120 assert o_scale .dtype == n_scale .dtype == config ["weight_dtype" ]
127121 assert torch .equal (o_scale , n_scale .to (o_scale .device ))
128- assert o_zp .dtype == n_zp .dtype
129- assert torch .equal (o_zp , n_zp .to (o_zp .device ))
130122
131123
132124@requires_gpu
0 commit comments