Skip to content

Commit ac326ee

Browse files
committed
Add zero-point compression for asymmetric quantization
Signed-off-by: shanjiaz <[email protected]>
1 parent 71f34d7 commit ac326ee

File tree

2 files changed

+31
-18
lines changed

2 files changed

+31
-18
lines changed

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 15 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -124,9 +124,21 @@ def compress(
124124
compressed_dict[prefix + key] = value.to(compression_device)
125125

126126
else:
127-
# omit saving zero points for symmetric or packed quantization
128-
if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
129-
continue
127+
# omit saving zero points for symmetric quantization
128+
if name.endswith("weight_zero_point"):
129+
module_path = name.rsplit(".", 1)[0]
130+
if (
131+
module_path in names_to_scheme
132+
and names_to_scheme[module_path].weights.symmetric
133+
):
134+
continue
135+
# Call compress_zp if available (for PackedQuantizationCompressor)
136+
if module_path in names_to_scheme and hasattr(self, "compress_zp"):
137+
value = self.compress_zp(
138+
value, names_to_scheme[module_path].weights
139+
)
140+
if value is None:
141+
continue
130142

131143
if name.endswith("weight_scale") and self._skip_scale():
132144
continue
@@ -140,21 +152,6 @@ def _skip_scale(self):
140152

141153
return isinstance(self, NVFP4PackedCompressor)
142154

143-
def _skip_zp(
144-
self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
145-
) -> bool:
146-
module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
147-
scheme = names_to_scheme[module_name]
148-
149-
if zp_name == "weight_zero_point":
150-
args = scheme.weights
151-
if zp_name == "input_zero_point":
152-
args = scheme.input_activations
153-
if zp_name == "output_zero_point":
154-
args = scheme.output_activations
155-
156-
return args.symmetric
157-
158155
def decompress(
159156
self,
160157
path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -184,6 +184,22 @@ def decompress_weight(
184184

185185
return decompressed_weight
186186

187+
def compress_zp(
    self, zero_point: Tensor, quantization_args: Optional[QuantizationArgs] = None
) -> Optional[Tensor]:
    """
    Prepare a weight zero-point tensor for saving in the compressed state dict.

    :param zero_point: zero-point tensor to compress; may be None
    :param quantization_args: quantization arguments of the weights this zero
        point belongs to. When omitted, the zero point is returned unchanged,
        since neither the symmetry nor the bit width needed for packing is
        known
    :return: None when there is nothing to save (no zero point, or symmetric
        quantization); otherwise the zero point, packed into int32 along
        dim 0 for the group and channel strategies and returned as-is for
        every other case
    """
    if zero_point is None:
        return None

    if quantization_args is None:
        # The signature allows the args to be omitted. Without them we cannot
        # tell whether the zero point is redundant, so keep it rather than
        # dereferencing None or silently dropping data.
        return zero_point

    if quantization_args.symmetric:
        # caller treats None as "do not save this tensor"
        return None

    if zero_point.dtype == torch.int32:
        # int32 input is passed through untouched
        # NOTE(review): presumably this means "already packed" - confirm
        return zero_point

    if quantization_args.strategy in (
        QuantizationStrategy.GROUP.value,
        QuantizationStrategy.CHANNEL.value,
    ):
        return pack_to_int32(
            zero_point, quantization_args.num_bits, packed_dim=0
        ).contiguous()

    return zero_point
202+
187203

188204
def pack_to_int32(
189205
value: torch.Tensor,

0 commit comments

Comments
 (0)