Commit 7dc3c9f

Add scale decompression support
Signed-off-by: shanjiaz <[email protected]>
1 parent 0dc048d commit 7dc3c9f

File tree

8 files changed: +138 -45 lines changed

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 0 additions & 8 deletions
@@ -140,18 +140,10 @@ def compress(
             if value is None:
                 continue

-            if name.endswith("weight_scale") and self._skip_scale():
-                continue
-
             compressed_dict[name] = value.to(compression_device)

         return compressed_dict

-    def _skip_scale(self):
-        from compressed_tensors.compressors import NVFP4PackedCompressor
-
-        return isinstance(self, NVFP4PackedCompressor)
-
     def decompress(
         self,
         path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
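
Note: with the `_skip_scale` shortcut gone, anything `compress_weight` returns (including `weight_scale` for NVFP4) passes through the loop above unchanged. A standalone sketch of that behavior, using placeholder tensors and a hypothetical module name:

import torch

# Hypothetical per-module output of compress_weight(); after this change the
# scale entry is kept instead of being skipped for NVFP4 compressors.
module_outputs = {
    "linear.weight_packed": torch.zeros(128, 64, dtype=torch.uint8),
    "linear.weight_scale": torch.ones(128, 16),  # real NVFP4 scales would be FP8
}

compressed_dict = {}
for name, value in module_outputs.items():
    if value is None:
        continue
    # no weight_scale special case anymore
    compressed_dict[name] = value.to("cpu")

assert "linear.weight_scale" in compressed_dict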

src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py

Lines changed: 25 additions & 2 deletions
@@ -21,8 +21,9 @@
     BaseQuantizationCompressor,
 )
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization import QuantizationArgs
+from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
 from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
+from compressed_tensors.quantization.utils import calculate_qparam_shape
 from torch import Tensor

@@ -56,7 +57,6 @@ def compression_param_names(self) -> Tuple[str]:
         return (
             "weight_packed",
             "weight_scale",
-            "weight_zero_point",
             "weight_global_scale",
         )

@@ -79,6 +79,24 @@ def compression_param_info(
                 torch.uint8,
             ),
         }
+
+        # Add weight_scale and weight_global_scale for NVFP4/MXFP4
+        if quantization_args is not None and quantization_args.strategy in [
+            QuantizationStrategy.GROUP.value,
+            QuantizationStrategy.TENSOR_GROUP.value,
+        ]:
+            # Use centralized calculation for consistency and correctness
+            num_groups, scale_shape = calculate_qparam_shape(
+                weight_shape, quantization_args
+            )
+            output["weight_scale"] = (scale_shape, quantization_args.scale_dtype)
+
+            if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP.value:
+                output["weight_global_scale"] = (
+                    torch.Size((1,)),
+                    torch.float32,
+                )
+
         return output

     def compress_weight(
@@ -104,6 +122,11 @@ def compress_weight(
             weight_packed = weight_packed.to(device)
         compressed_dict["weight_packed"] = weight_packed
         compressed_dict["weight_scale"] = scale.to(quantization_args.scale_dtype)
+
+        # Include global_scale if provided (for TENSOR_GROUP strategy)
+        if global_scale is not None:
+            compressed_dict["weight_global_scale"] = global_scale
+
         return compressed_dict

     def decompress_weight(
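
For a sense of what the new `compression_param_info` branch reports, here is a standalone sketch of the scale shape for a group-quantized NVFP4-style weight; the 4096x4096 layer, group size of 16, and the printed shapes are illustrative assumptions, not values from the commit:

import math
import torch

weight_shape = torch.Size((4096, 4096))  # assumed example layer
group_size = 16                          # assumed NVFP4-style group size

# ceiling division over the last dim, mirroring calculate_qparam_shape for
# GROUP / TENSOR_GROUP strategies
num_groups = math.ceil(weight_shape[-1] / group_size)
scale_shape = torch.Size((weight_shape[0], num_groups))

print(scale_shape)       # torch.Size([4096, 256])
print(torch.Size((1,)))  # weight_global_scale shape added for TENSOR_GROUP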

src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py

Lines changed: 11 additions & 1 deletion
@@ -111,7 +111,17 @@ def compress_weight(
         if device is not None:
             quantized_weight = quantized_weight.to(device)

-        return {"weight": quantized_weight}
+        compressed_dict = {"weight": quantized_weight}
+
+        # Include scale, zero_point, and g_idx if they exist
+        if scale is not None:
+            compressed_dict["weight_scale"] = scale
+        if zero_point is not None:
+            compressed_dict["weight_zero_point"] = zero_point
+        if g_idx is not None:
+            compressed_dict["weight_g_idx"] = g_idx
+
+        return compressed_dict

     def decompress_weight(
         self,
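
A minimal standalone sketch of the dict `compress_weight` now returns when quantization parameters are present (key names mirror the diff; tensors and shapes are placeholders):

import torch

def build_compressed_dict(quantized_weight, scale=None, zero_point=None, g_idx=None):
    # mirrors the new behavior: keep scale / zero_point / g_idx when they exist
    compressed_dict = {"weight": quantized_weight}
    if scale is not None:
        compressed_dict["weight_scale"] = scale
    if zero_point is not None:
        compressed_dict["weight_zero_point"] = zero_point
    if g_idx is not None:
        compressed_dict["weight_g_idx"] = g_idx
    return compressed_dict

out = build_compressed_dict(
    torch.zeros(8, 8, dtype=torch.int8),
    scale=torch.ones(8, 1),
    zero_point=torch.zeros(8, 1, dtype=torch.int8),
)
print(sorted(out.keys()))  # ['weight', 'weight_scale', 'weight_zero_point']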

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 46 additions & 24 deletions
@@ -22,7 +22,7 @@
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
 from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
-from compressed_tensors.quantization.utils import can_quantize
+from compressed_tensors.quantization.utils import calculate_qparam_shape, can_quantize
 from torch import Tensor

@@ -69,20 +69,26 @@ def compression_param_info(
             "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
             "weight_shape": (torch.Size((2,)), torch.int32),
         }
-        if not quantization_args.symmetric and quantization_args.strategy in [
+
+        # Add weight_scale - always needed for quantization
+        if quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
         ]:
-            zp_factor = (
-                quantization_args.group_size
-                if quantization_args.strategy == QuantizationStrategy.GROUP.value
-                else weight_shape[-1]
+            # Use centralized calculation for consistency and correctness
+            num_groups, scale_shape = calculate_qparam_shape(
+                weight_shape, quantization_args
             )
+            output["weight_scale"] = (scale_shape, quantization_args.scale_dtype)
+
+            # Add weight_zero_point for asymmetric quantization
+            # Zero point has same num_groups as scale, but with packed rows
+            if not quantization_args.symmetric:
+                output["weight_zero_point"] = (
+                    torch.Size((packed_size_zp, num_groups)),
+                    torch.int32,
+                )

-            output["weight_zero_point"] = (
-                torch.Size((packed_size_zp, weight_shape[-1] // zp_factor)),
-                torch.int32,
-            )
         return output

     def compress_weight(
@@ -126,22 +132,36 @@ def compress_weight(

         packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)

-        weight_shape = torch.tensor(weight.shape)
+        weight_shape = torch.tensor(weight.shape, dtype=torch.int32)
         if device is not None:
             packed_weight = packed_weight.to(device)
             weight_shape = weight_shape.to(device)

         compressed_dict["weight_shape"] = weight_shape
         compressed_dict["weight_packed"] = packed_weight

-        if not quantization_args.symmetric and quantization_args.strategy in [
-            QuantizationStrategy.GROUP.value,
-            QuantizationStrategy.CHANNEL.value,
-        ]:
-            packed_zp = pack_to_int32(
-                zero_point, quantization_args.num_bits, packed_dim=0
-            )
-            compressed_dict["weight_zero_point"] = packed_zp.contiguous()
+        # Include scale if provided
+        if scale is not None:
+            compressed_dict["weight_scale"] = scale
+
+        # Include zero_point if provided
+        if zero_point is not None:
+            if not quantization_args.symmetric and quantization_args.strategy in [
+                QuantizationStrategy.GROUP.value,
+                QuantizationStrategy.CHANNEL.value,
+            ]:
+                packed_zp = pack_to_int32(
+                    zero_point, quantization_args.num_bits, packed_dim=0
+                )
+                compressed_dict["weight_zero_point"] = packed_zp.contiguous()
+            else:
+                # For symmetric or other strategies, include unpacked zero_point
+                compressed_dict["weight_zero_point"] = zero_point
+
+        # Include g_idx if provided
+        if g_idx is not None:
+            compressed_dict["weight_g_idx"] = g_idx
+
         return compressed_dict

     def decompress_weight(
@@ -172,11 +192,13 @@ def decompress_weight(
                 zero_point is not None
             ), "Asymmetric quantization requires zero-point values"
             original_zp_shape = (original_shape[0], scale.shape[-1])
-            zero_point = unpack_from_int32(
-                zero_point, num_bits, original_zp_shape, packed_dim=0
-            )
-            # Update the compressed_data dict with the unpacked zero_point
-            compressed_data["weight_zero_point"] = zero_point
+            # Only unpack if it's still packed (int32)
+            if zero_point.dtype == torch.int32:
+                zero_point = unpack_from_int32(
+                    zero_point, num_bits, original_zp_shape, packed_dim=0
+                )
+                # Update the compressed_data dict with the unpacked zero_point
+                compressed_data["weight_zero_point"] = zero_point

         decompressed_weight = dequantize(
             x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
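
A standalone sketch of the new decompression guard: the zero point is only unpacked when it still carries the packed int32 dtype, so a tensor that was already unpacked passes through untouched. `fake_unpack` below is a stand-in for the library's `unpack_from_int32`, not its real implementation:

import torch

def fake_unpack(packed: torch.Tensor) -> torch.Tensor:
    # placeholder: the real helper expands num_bits values out of each int32 word
    return packed.to(torch.int8)

def maybe_unpack_zero_point(zero_point: torch.Tensor) -> torch.Tensor:
    # only unpack if it's still packed (int32)
    if zero_point.dtype == torch.int32:
        zero_point = fake_unpack(zero_point)
    return zero_point

packed = torch.zeros(2, 4, dtype=torch.int32)
already_unpacked = torch.zeros(2, 8, dtype=torch.int8)

assert maybe_unpack_zero_point(packed).dtype == torch.int8
assert maybe_unpack_zero_point(already_unpacked) is already_unpacked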

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 45 additions & 0 deletions
@@ -48,6 +48,7 @@
     "calculate_qparams",
     "generate_gparam",
     "strategy_cdiv",
+    "calculate_qparam_shape",
 ]

 # target the self_attn layer
@@ -448,6 +449,50 @@ def strategy_cdiv(
     return dividend


+def calculate_qparam_shape(
+    weight_shape: torch.Size,
+    quantization_args: QuantizationArgs,
+) -> Tuple[int, torch.Size]:
+    """
+    Calculate the number of groups and scale/zero_point shape for quantization.
+
+    This centralizes the logic for determining quantization parameter shapes,
+    ensuring consistency with initialize_qparams and avoiding floor division bugs.
+
+    :param weight_shape: shape of the weight tensor to be quantized
+    :param quantization_args: quantization configuration
+    :return: tuple of (num_groups, expected_shape) where:
+        - num_groups: number of quantization groups
+        - expected_shape: shape for scale/zero_point tensors (weight_shape[0], num_groups)
+    """
+    strategy = quantization_args.strategy
+
+    if strategy == QuantizationStrategy.TENSOR:
+        num_groups = 1
+        expected_shape = torch.Size((1,))
+
+    elif strategy == QuantizationStrategy.CHANNEL:
+        num_groups = 1
+        expected_shape = torch.Size((weight_shape[0], 1))
+
+    elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        group_size = quantization_args.group_size
+        if group_size is None:
+            raise ValueError(f"{strategy} quantization requires group_size to be set")
+
+        # Use strategy_cdiv for proper ceiling division and validation
+        num_groups = strategy_cdiv(weight_shape[-1], group_size, strategy)
+        expected_shape = torch.Size((weight_shape[0], num_groups))
+
+    else:
+        raise ValueError(
+            f"Unsupported quantization strategy: {strategy}. "
+            f"Supported strategies: TENSOR, CHANNEL, GROUP, TENSOR_GROUP"
+        )
+
+    return num_groups, expected_shape
+
+
 def _get_dtype_eps(dtype: torch.dtype) -> float:
     if dtype == FP8_E4M3_DATA.dtype:
         return 0.125
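
A usage example of the new helper (assuming a build that contains this commit, and that `QuantizationArgs` accepts the fields shown); per the docstring, the shapes are meant to match what initialize_qparams allocates:

import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.utils import calculate_qparam_shape

# group-wise: 1024 columns / group_size 128 -> 8 groups, one scale per (row, group)
args = QuantizationArgs(
    num_bits=4, strategy=QuantizationStrategy.GROUP, group_size=128, symmetric=True
)
num_groups, shape = calculate_qparam_shape(torch.Size((512, 1024)), args)
print(num_groups, shape)  # 8 torch.Size([512, 8])

# channel-wise: a single scale per output row
args_ch = QuantizationArgs(num_bits=8, strategy=QuantizationStrategy.CHANNEL)
print(calculate_qparam_shape(torch.Size((512, 1024)), args_ch))  # (1, torch.Size([512, 1]))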

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 7 additions & 3 deletions
@@ -214,6 +214,7 @@ def test_composability(tmp_path, sparsity_config, quantization_config):
                 "linear.row_offsets",
                 "linear.shape",
                 "linear.weight_scale",
+                "linear.weight_zero_point",
             },
         )
     ],
@@ -572,9 +573,12 @@ def test_decompress_model(model_stub, comp_stub):
     # equivalent to decompressing from disk
     assert decompressed.keys() == true_decompressed.keys()
     for key in decompressed.keys():
-        assert (
-            decompressed[key].dtype == true_decompressed[key].dtype
-        ), f"{key} dtypes not equal"
+        # Skip dtype check for weight_shape - int32/int64 are functionally equivalent
+        # torch.Size() works identically with both, old checkpoints use int64, new use int32
+        if not key.endswith("weight_shape"):
+            assert (
+                decompressed[key].dtype == true_decompressed[key].dtype
+            ), f"{key} dtypes not equal"
         assert torch.all(
            decompressed[key] == true_decompressed[key]
        ), f"{key} values not equal"

tests/test_compressors/quantized_compressors/test_fp8_quant.py

Lines changed: 2 additions & 2 deletions
@@ -89,8 +89,8 @@ def test_quant_format(strategy, group_size, sc, zp):
         dense_state_dict, names_to_scheme=module_name_to_scheme
     )

-    # state_dict params should be the same, minus the zero_point if symmetric
-    assert len(dense_state_dict) == len(compressed_state_dict) + 1
+    # state_dict params should be the same (zero_point included even for symmetric)
+    assert len(dense_state_dict) == len(compressed_state_dict)

     # check compressed to int8
     assert compressed_state_dict["dummy.weight_scale"].dtype == torch.float32

tests/test_compressors/quantized_compressors/test_int_quant.py

Lines changed: 2 additions & 5 deletions
@@ -81,11 +81,8 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp):
         dense_state_dict, names_to_scheme=quantized_modules_to_scheme
     )

-    # state_dict params should be the same, minus the zero_point if symmetric
-    if symmetric:
-        assert len(dense_state_dict) == len(compressed_state_dict) + 1
-    else:
-        assert len(dense_state_dict) == len(compressed_state_dict)
+    # state_dict params should be the same (zero_point included even for symmetric)
+    assert len(dense_state_dict) == len(compressed_state_dict)

     # check compressed to int8
     assert compressed_state_dict["dummy.weight"].dtype == torch.int8
