
Commit b019b89

[MXFP4] Add calibration support (#509)
Squashed commits: add back test, fix serialization, fix condition, remove torch, update tests, fix comment, fix typo, rebase fixes, dequant scales not supported, plus assorted "update" commits.

Signed-off-by: Dipika Sikka <[email protected]>
1 parent 1ec8bb6 commit b019b89

6 files changed: +128, -28 lines

src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py

Lines changed: 26 additions & 2 deletions
@@ -81,6 +81,14 @@ def compression_param_info(
         }
         return output

+    def compress_scale(
+        self,
+        scale: Tensor,
+        quantization_args: QuantizationArgs,
+    ) -> Dict[str, torch.Tensor]:
+        assert quantization_args.scale_dtype is not None
+        return scale.to(quantization_args.scale_dtype)
+
     def compress_weight(
         self,
         weight: Tensor,
@@ -103,7 +111,9 @@ def compress_weight(
         if device is not None:
             weight_packed = weight_packed.to(device)
         compressed_dict["weight_packed"] = weight_packed
-        compressed_dict["weight_scale"] = scale.to(quantization_args.scale_dtype)
+        compressed_dict["weight_scale"] = self.compress_scale(
+            scale=scale, quantization_args=quantization_args
+        )
         return compressed_dict

     def decompress_weight(
@@ -130,7 +140,21 @@ class MXFP4PackedCompressor(NVFP4PackedCompressor):
     Alias for mxfp4 quantized models
     """

-    pass
+    def compress_scale(
+        self,
+        scale: Tensor,
+        quantization_args: QuantizationArgs,
+    ) -> Dict[str, torch.Tensor]:
+        assert quantization_args.scale_dtype is not None
+        scale_exp = 127 + torch.floor(torch.log2(scale)).to(torch.int32) - 2
+        return scale_exp.to(quantization_args.scale_dtype)
+
+    def decompress_weight(
+        self,
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
+        raise NotImplementedError("MXFP4 Decompression is currently not supported")


 @torch.compile(fullgraph=True, dynamic=True)
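
For context (not part of this commit): MXFP4PackedCompressor.compress_scale stores each group scale as a biased exponent, 127 + floor(log2(scale)) - 2, cast to scale_dtype (uint8 in the new presets). A standalone sketch of that encode step and its inverse, using example scale values:

import torch

# Power-of-two group scales (example values)
scale = torch.tensor([0.5, 2.0, 0.03125])

# Encode the way compress_scale does above, then pack into uint8
scale_exp = 127 + torch.floor(torch.log2(scale)).to(torch.int32) - 2
packed = scale_exp.to(torch.uint8)  # what ends up in "weight_scale"

# Inverse mapping used by the mxfp4 utils elsewhere in this commit
decoded = 2.0 ** (packed.to(torch.int32) - 127).to(torch.float)

# decoded == scale / 4 because of the extra "- 2" applied at encode time
assert torch.allclose(decoded * 4, scale)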

src/compressed_tensors/config/format.py

Lines changed: 2 additions & 0 deletions
@@ -50,6 +50,8 @@ def _get_quant_compression_format(
     is_weight_only = weight_args is not None and input_args is None

     if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
+        if weight_args.group_size == 32:
+            return CompressionFormat.mxfp4_pack_quantized
         return CompressionFormat.nvfp4_pack_quantized

     if is_weight_only: # w4a16 and w8a16
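
The branch added above routes 4-bit float weights with 32-element groups to the MXFP4 packed format, while other FP4 configurations (e.g. NVFP4's 16-element groups) keep nvfp4_pack_quantized. A simplified standalone sketch of that decision; pick_fp4_format is a hypothetical helper, not the library's private _get_quant_compression_format:

from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.quant_args import QuantizationArgs, QuantizationType


def pick_fp4_format(weight_args: QuantizationArgs) -> CompressionFormat:
    # Mirrors the condition above: FP4 weights with group_size == 32 are MXFP4.
    assert weight_args.num_bits == 4
    assert weight_args.type == QuantizationType.FLOAT.value
    if weight_args.group_size == 32:
        return CompressionFormat.mxfp4_pack_quantized
    return CompressionFormat.nvfp4_pack_quantized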

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 40 additions & 1 deletion
@@ -11,11 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import warnings
 from copy import deepcopy
 from typing import List, Optional

+import torch
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import (
     FP8_E4M3_DATA,
@@ -192,6 +192,43 @@ def is_preset_scheme(name: str) -> bool:
     ),
 )

+MXFP4A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+        group_size=32,
+        scale_dtype=torch.uint8,
+        zp_dtype=torch.uint8,
+    )
+)
+
+MXFP4 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+        group_size=32,
+        scale_dtype=torch.uint8,
+        zp_dtype=torch.uint8,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        dynamic=True,
+        symmetric=True,
+        group_size=32,
+        scale_dtype=torch.uint8,
+        zp_dtype=torch.uint8,
+    ),
+)
+
+
 # 8 bit integer weights and 8 bit activations quantization
 INT8_W8A8 = dict(
     weights=QuantizationArgs(
@@ -343,4 +380,6 @@ def is_preset_scheme(name: str) -> bool:
     "FP8_BLOCK": FP8_BLOCK,
     "NVFP4A16": NVFP4A16,
     "NVFP4": NVFP4,
+    "MXFP4A16": MXFP4A16,
+    "MXFP4": MXFP4,
 }
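
For context, the two new preset names resolve like the existing ones. A small check (not from this commit; "MXFP5" is just an arbitrary non-preset name):

from compressed_tensors.quantization.quant_scheme import is_preset_scheme

assert is_preset_scheme("MXFP4A16")   # weight-only mxfp4 (W4A16-style)
assert is_preset_scheme("MXFP4")      # mxfp4 weights + dynamic mxfp4 activations
assert not is_preset_scheme("MXFP5")  # arbitrary non-preset name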

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 14 additions & 3 deletions
@@ -27,6 +27,11 @@
     round_to_quantized_type_dtype,
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.quantization.utils.mxfp4_utils import (
+    generate_mxfp4_scales,
+    maybe_convert_from_mxfp4_exp,
+    should_generatre_mxfp4_scales,
+)
 from compressed_tensors.utils import deprecated
 from loguru import logger
 from torch import FloatTensor, IntTensor, Tensor
@@ -88,7 +93,10 @@ def calculate_qparams(
     # 1. Generate scale and zero-point
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
-        scales = max_val_pos / (float(bit_range) / 2)
+        if should_generatre_mxfp4_scales(args=quantization_args):
+            scales = generate_mxfp4_scales(x=max_val_pos)
+        else:
+            scales = max_val_pos / (float(bit_range) / 2)
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         if (
@@ -112,7 +120,10 @@ def calculate_qparams(
         scales, dtype=quantization_args.scale_dtype
     )

-    # 4. Update any 0s with small values to
+    # 4. Optionally remove exponent
+    scales = maybe_convert_from_mxfp4_exp(quantization_args, scales)
+
+    # 5. Update any 0s with small values to
     # prevent div by 0
     eps = _get_dtype_eps(
         dtype=quantization_args.scale_dtype
@@ -125,7 +136,7 @@ def calculate_qparams(
         scales,
     )

-    # 5. Round the zp to zp_dtype
+    # 6. Round the zp to zp_dtype
     zero_points = round_to_quantized_type_dtype(
         zero_points, dtype=quantization_args.zp_dtype, cast_to_original_dtype=False
     )
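
The gate that decides whether calculate_qparams takes the MXFP4 scale path can be exercised on its own. A minimal sketch (not from this commit), with QuantizationArgs mirroring the MXFP4A16 preset; should_generatre_mxfp4_scales is the function name as defined in the commit:

import torch
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)
from compressed_tensors.quantization.utils.mxfp4_utils import should_generatre_mxfp4_scales

# Args mirroring the MXFP4A16 preset added in quant_scheme.py
args = QuantizationArgs(
    num_bits=4,
    type=QuantizationType.FLOAT,
    strategy=QuantizationStrategy.GROUP,
    symmetric=True,
    group_size=32,
    scale_dtype=torch.uint8,
    zp_dtype=torch.uint8,
)

# True here, so calculate_qparams takes the generate_mxfp4_scales branch;
# other group sizes (e.g. NVFP4's 16) fall back to max_val / (bit_range / 2).
assert should_generatre_mxfp4_scales(args=args)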

src/compressed_tensors/quantization/utils/mxfp4_utils.py

Lines changed: 25 additions & 19 deletions
@@ -13,16 +13,29 @@
 # limitations under the License.

 import torch
-from compressed_tensors.quantization.quant_args import BFLOAT16_DATA, FP4_E2M1_DATA
+from compressed_tensors.quantization.quant_args import (
+    BFLOAT16_DATA,
+    FP4_E2M1_DATA,
+    QuantizationArgs,
+)


-__all__ = ["convert_mxfp4_exp_scale", "generate_mxfp4_scales", "round_to_power_2"]
+__all__ = [
+    "maybe_convert_from_mxfp4_exp",
+    "generate_mxfp4_scales",
+    "round_to_power_2",
+    "should_generatre_mxfp4_scales",
+]

 # Reference: https://github.com/vllm-project/vllm/blob/main/tests/quantization/reference_mxfp4.py # noqa: E501


-def convert_mxfp4_exp_scale(
-    scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16
+def should_generatre_mxfp4_scales(args: QuantizationArgs):
+    return args.num_bits == 4 and args.type == "float" and args.group_size == 32
+
+
+def maybe_convert_from_mxfp4_exp(
+    args: QuantizationArgs, scale: torch.Tensor
 ) -> torch.Tensor:
     """
     Converts mxfp4 scales. Scales are powers of 2, with the
@@ -32,10 +45,12 @@ def convert_mxfp4_exp_scale(
     :param scale: uint8 exponent scale
     :param dtype: dense dtype
     """
-    assert scale.dtype == torch.uint8
-    scale_exp = scale.to(torch.int32) - 127
-    scale = 2.00 ** (scale_exp.to(torch.float))
-    return scale.to(dtype)
+    original_dtype = scale.dtype
+    if should_generatre_mxfp4_scales(args):
+        scale_exp = scale.to(torch.int32) - 127
+        scale = 2.00 ** (scale_exp.to(torch.float))
+        return scale.to(original_dtype)
+    return scale


 def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
@@ -77,21 +92,12 @@ def generate_mxfp4_scales(x: torch.Tensor) -> torch.Tensor:
     Generate mxfp4 scales. The scales require the following steps
     1. Round to the closest power of 2
     2. Convert to exponent
-    3. Store in uint8

     Called when calculating qparams using observers.

     :param x: tensor to round to closest power of 2
-    :returns uint8 scales as exponents
+    :returns scales as exponents
     """
     # Round to closest power of 2
     scale_power_2 = round_to_power_2(x)
-    # Convert to exponent
-    scale_exp = 127 + torch.floor(torch.log2(scale_power_2)).to(torch.int32) - 2
-    # Clamp and store in uint8, as expected by mxfp4
-    scale_exp = torch.clamp(
-        scale_exp,
-        max=torch.iinfo(torch.uint8).max,
-        min=torch.iinfo(torch.uint8).min,
-    )
-    return scale_exp.to(torch.uint8)
+    return 127 + torch.floor(torch.log2(scale_power_2)) - 2
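
A minimal round trip through the helpers above (a sketch, not from this commit; block_max uses example values and args mirrors the new MXFP4 presets):

import torch
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)
from compressed_tensors.quantization.utils.mxfp4_utils import (
    generate_mxfp4_scales,
    maybe_convert_from_mxfp4_exp,
    round_to_power_2,
)

args = QuantizationArgs(
    num_bits=4,
    type=QuantizationType.FLOAT,
    strategy=QuantizationStrategy.GROUP,
    group_size=32,
    scale_dtype=torch.uint8,
    zp_dtype=torch.uint8,
)

block_max = torch.tensor([0.11, 0.52, 3.9])             # per-group abs-max (example values)
exp_scales = generate_mxfp4_scales(block_max)           # 127 + floor(log2(power-of-2)) - 2
scales = maybe_convert_from_mxfp4_exp(args, exp_scales)  # back to 2 ** (exp - 127)

# Each recovered scale is the rounded power of two divided by 2 ** 2,
# matching the "- 2" bias applied when generating the exponent.
assert torch.allclose(scales.float() * 4, round_to_power_2(block_max).float())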

tests/test_quantization/test_utils/test_mxfp4_utils.py

Lines changed: 21 additions & 3 deletions
@@ -13,9 +13,10 @@
 # limitations under the License.

 import torch
+from compressed_tensors.quantization import round_to_quantized_type_dtype
 from compressed_tensors.quantization.utils import (
-    convert_mxfp4_exp_scale,
     generate_mxfp4_scales,
+    maybe_convert_from_mxfp4_exp,
     round_to_power_2,
 )

@@ -61,6 +62,12 @@ def test_round_power_2():


 def test_mxfp4_scales_e2e():
+    from compressed_tensors.quantization.quant_args import (
+        QuantizationArgs,
+        QuantizationStrategy,
+        QuantizationType,
+    )
+
     mock_weight = torch.normal(mean=0.0002, std=0.0576, size=(2880, 2880))

     x = mock_weight.reshape(*mock_weight.shape[:-1], -1, 32).to(torch.bfloat16)
@@ -71,8 +78,19 @@ def test_mxfp4_scales_e2e():
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
     block_max = torch.max(torch.abs(min_vals), torch.abs(max_vals))

-    scales_generated = generate_mxfp4_scales(block_max)
-    converted_ct = convert_mxfp4_exp_scale(scales_generated)
+    args = QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        group_size=32,
+        scale_dtype=torch.uint8,
+        zp_dtype=torch.uint8,
+    )
+
+    scales = generate_mxfp4_scales(block_max)
+    scales = round_to_quantized_type_dtype(scales, dtype=args.scale_dtype)
+
+    converted_ct = maybe_convert_from_mxfp4_exp(args=args, scale=scales)

     scales_exp = torch.log2(converted_ct)
     block_max_exp = torch.floor(torch.log2(round_to_power_2(block_max))) - 2
