Skip to content

Commit 35e5442

Browse files
committed
cleanup
Signed-off-by: shanjiaz <[email protected]>
1 parent 30f40b0 commit 35e5442

File tree

6 files changed: +25 additions, -48 deletions

src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -111,17 +111,7 @@ def compress_weight(
111111
if device is not None:
112112
quantized_weight = quantized_weight.to(device)
113113

114-
compressed_dict = {"weight": quantized_weight}
115-
116-
# Include scale, zero_point, and g_idx if they exist
117-
if scale is not None:
118-
compressed_dict["weight_scale"] = scale
119-
if zero_point is not None:
120-
compressed_dict["weight_zero_point"] = zero_point
121-
if g_idx is not None:
122-
compressed_dict["weight_g_idx"] = g_idx
123-
124-
return compressed_dict
114+
return {"weight": quantized_weight}
125115

126116
def decompress_weight(
127117
self,

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -132,36 +132,22 @@ def compress_weight(
132132

133133
packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
134134

135-
weight_shape = torch.tensor(weight.shape, dtype=torch.int32)
135+
weight_shape = torch.tensor(weight.shape)
136136
if device is not None:
137137
packed_weight = packed_weight.to(device)
138138
weight_shape = weight_shape.to(device)
139139

140140
compressed_dict["weight_shape"] = weight_shape
141141
compressed_dict["weight_packed"] = packed_weight
142142

143-
# Include scale if provided
144-
if scale is not None:
145-
compressed_dict["weight_scale"] = scale
146-
147-
# Include zero_point if provided
148-
if zero_point is not None:
149-
if not quantization_args.symmetric and quantization_args.strategy in [
150-
QuantizationStrategy.GROUP.value,
151-
QuantizationStrategy.CHANNEL.value,
152-
]:
153-
packed_zp = pack_to_int32(
154-
zero_point, quantization_args.num_bits, packed_dim=0
155-
)
156-
compressed_dict["weight_zero_point"] = packed_zp.contiguous()
157-
else:
158-
# For symmetric or other strategies, include unpacked zero_point
159-
compressed_dict["weight_zero_point"] = zero_point
160-
161-
# Include g_idx if provided
162-
if g_idx is not None:
163-
compressed_dict["weight_g_idx"] = g_idx
164-
143+
if not quantization_args.symmetric and quantization_args.strategy in [
144+
QuantizationStrategy.GROUP.value,
145+
QuantizationStrategy.CHANNEL.value,
146+
]:
147+
packed_zp = pack_to_int32(
148+
zero_point, quantization_args.num_bits, packed_dim=0
149+
)
150+
compressed_dict["weight_zero_point"] = packed_zp.contiguous()
165151
return compressed_dict
166152

167153
def decompress_weight(
@@ -192,13 +178,11 @@ def decompress_weight(
192178
zero_point is not None
193179
), "Asymmetric quantization requires zero-point values"
194180
original_zp_shape = (original_shape[0], scale.shape[-1])
195-
# Only unpack if it's still packed (int32)
196-
if zero_point.dtype == torch.int32:
197-
zero_point = unpack_from_int32(
198-
zero_point, num_bits, original_zp_shape, packed_dim=0
199-
)
200-
# Update the compressed_data dict with the unpacked zero_point
201-
compressed_data["weight_zero_point"] = zero_point
181+
zero_point = unpack_from_int32(
182+
zero_point, num_bits, original_zp_shape, packed_dim=0
183+
)
184+
# Update the compressed_data dict with the unpacked zero_point
185+
compressed_data["weight_zero_point"] = zero_point
202186

203187
decompressed_weight = dequantize(
204188
x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ def test_composability(tmp_path, sparsity_config, quantization_config):
214214
"linear.row_offsets",
215215
"linear.shape",
216216
"linear.weight_scale",
217-
"linear.weight_zero_point",
218217
},
219218
)
220219
],

tests/test_compressors/quantized_compressors/test_fp8_quant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def test_quant_format(strategy, group_size, sc, zp):
8989
dense_state_dict, names_to_scheme=module_name_to_scheme
9090
)
9191

92-
# state_dict params should be the same (zero_point included even for symmetric)
93-
assert len(dense_state_dict) == len(compressed_state_dict)
92+
# state_dict params should be the same, minus the zero_point if symmetric
93+
assert len(dense_state_dict) == len(compressed_state_dict) + 1
9494

9595
# check compressed to int8
9696
assert compressed_state_dict["dummy.weight_scale"].dtype == torch.float32

tests/test_compressors/quantized_compressors/test_int_quant.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,11 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp):
8181
dense_state_dict, names_to_scheme=quantized_modules_to_scheme
8282
)
8383

84-
# state_dict params should be the same (zero_point included even for symmetric)
85-
assert len(dense_state_dict) == len(compressed_state_dict)
84+
# state_dict params should be the same, minus the zero_point if symmetric
85+
if symmetric:
86+
assert len(dense_state_dict) == len(compressed_state_dict) + 1
87+
else:
88+
assert len(dense_state_dict) == len(compressed_state_dict)
8689

8790
# check compressed to int8
8891
assert compressed_state_dict["dummy.weight"].dtype == torch.int8

tests/test_compressors/quantized_compressors/test_pack_quant.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,9 @@ def test_quant_format(shape):
8888
dense_state_dict, names_to_scheme=quantized_modules_to_scheme
8989
)
9090

91-
# compressed state_dict adds one entry for shape and keeps zero_point
92-
assert len(dense_state_dict) + 1 == len(compressed_state_dict)
91+
# compressed state_dict adds one entry for shape
92+
# but removes the zero points since we are symmetric
93+
assert len(dense_state_dict) == len(compressed_state_dict)
9394

9495
# check compressed and packed
9596
assert compressed_state_dict["dummy.weight_packed"].dtype == torch.int32

0 commit comments

Comments (0)