Skip to content

Commit ca3ca7d

Browse files
committed
fix tests
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 4b74e4e commit ca3ca7d

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

tests/llmcompressor/transformers/compression/test_compress_tensor_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def test_quant_model_reload(format, dtype, tmp_path):
8787
quantization_config=CompressedTensorsConfig(run_compressed=False),
8888
)
8989

90+    _remove_zp(og_state_dict)  # HACK: remove extra zero points added during quant init
9091
reconstructed_state_dict = decompressed_model.state_dict()
9192
assert len(og_state_dict) == len(reconstructed_state_dict)
9293
for key in og_state_dict.keys():
@@ -275,3 +276,11 @@ def test_correct_compressor_inferred(
275276
model.linear.quantization_status = QuantizationStatus.FROZEN
276277

277278
assert infer_model_format(model) == expected_format
279+
280+
281+ def _remove_zp(state_dict: dict) -> dict:
282+     return {
283+         key: value
284+         for key, value in state_dict.items()
285+         if not key.endswith("zero_point")
286+     }

tests/llmcompressor/transformers/compression/test_quantization.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,16 @@ def _get_quant_info(model):
3939
for name, module in model.named_modules():
4040
with align_module_device(module):
4141
if is_module_quantized(module):
42+
# skip zero points, as these are removed between
43+
# compression/decompression for symmetric models
44+
4245
if module.quantization_scheme.weights is not None:
43-
quant_info_weights[name] = (
44-
module.weight_scale,
45-
module.weight_zero_point,
46-
module.weight,
47-
)
46+
quant_info_weights[name] = (module.weight_scale, module.weight)
4847

4948
if module.quantization_scheme.input_activations is not None:
5049
is_dynamic = module.quantization_scheme.input_activations.dynamic
5150
if not is_dynamic:
52-
quant_info_inputs[name] = (
53-
module.input_scale,
54-
module.input_zero_point,
55-
)
51+
quant_info_inputs[name] = (module.input_scale,)
5652

5753
return quant_info_weights, quant_info_inputs
5854

@@ -110,23 +106,19 @@ def test_quantization_reload(setup_model_and_config):
110106
# TODO: can remove `to` calls after
111107
# https://github.com/neuralmagic/compressed-tensors/pull/427
112108

113-
for name, (o_scale, o_zp, o_weight) in og_weights.items():
114-
n_scale, n_zp, n_weight = reloaded_weights[name]
109+
for name, (o_scale, o_weight) in og_weights.items():
110+
n_scale, n_weight = reloaded_weights[name]
115111
assert o_scale.dtype == n_scale.dtype == config["weight_dtype"]
116112
assert torch.equal(o_scale, n_scale.to(o_scale.device))
117-
assert o_zp.dtype == n_zp.dtype
118-
assert torch.equal(o_zp, n_zp.to(o_zp.device))
119113

120114
# we don't expect an exact match here because o_weight still has the
121115
# original weight and n_weight has been fake_quantized
122116
assert n_weight.dtype == o_weight.dtype == config["weight_dtype"]
123117

124-
for name, (o_scale, o_zp) in og_inputs.items():
125-
n_scale, n_zp = reloaded_inputs[name]
118+
for name, (o_scale,) in og_inputs.items():
119+
(n_scale,) = reloaded_inputs[name]
126120
assert o_scale.dtype == n_scale.dtype == config["weight_dtype"]
127121
assert torch.equal(o_scale, n_scale.to(o_scale.device))
128-
assert o_zp.dtype == n_zp.dtype
129-
assert torch.equal(o_zp, n_zp.to(o_zp.device))
130122

131123

132124
@requires_gpu

0 commit comments

Comments
 (0)