Commit 9dfb31c

update

1 parent db4dd63 commit 9dfb31c

3 files changed: +32 -48 lines

src/compressed_tensors/quantization/utils/helpers.py (11 additions, 8 deletions)

@@ -28,8 +28,9 @@
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils.mxfp4_utils import (
-    maybe_convert_from_mxfp4_scale,
-    maybe_convert_to_mxfp4_scales,
+    generate_mxfp4_scales,
+    maybe_convert_from_mxfp4_exp,
+    should_generatre_mxfp4_scales,
 )
 from compressed_tensors.utils import deprecated
 from loguru import logger
@@ -92,8 +93,10 @@ def calculate_qparams(
     # 1. Generate scale and zero-point
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
-        # scales = max_val_pos / (float(bit_range) / 2)
-        scales = maybe_convert_to_mxfp4_scales(max_val_pos)
+        if should_generatre_mxfp4_scales(args=quantization_args):
+            scales = generate_mxfp4_scales(x=max_val_pos)
+        else:
+            scales = max_val_pos / (float(bit_range) / 2)
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         if (
@@ -117,10 +120,10 @@ def calculate_qparams(
             scales, dtype=quantization_args.scale_dtype
         )
 
-    # Optionally remove exponent
-    scales = maybe_convert_from_mxfp4_scale(quantization_args, scales)
+    # 4. Optionally remove exponent
+    scales = maybe_convert_from_mxfp4_exp(quantization_args, scales)
 
-    # 4. Update any 0s with small values to
+    # 5. Update any 0s with small values to
     # prevent div by 0
     eps = _get_dtype_eps(
         dtype=quantization_args.scale_dtype
@@ -133,7 +136,7 @@ def calculate_qparams(
         scales,
     )
 
-    # 5. Round the zp to zp_dtype
+    # 6. Round the zp to zp_dtype
     zero_points = round_to_quantized_type_dtype(
         zero_points, dtype=quantization_args.zp_dtype, cast_to_original_dtype=False
    )
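
As a quick orientation before the next file: the new gate simply pattern-matches the quantization args for MXFP4. Below is a minimal sketch of the predicate this commit adds, using a hypothetical stand-in dataclass in place of QuantizationArgs so it runs standalone:

from dataclasses import dataclass

@dataclass
class Args:
    # Hypothetical stand-in for QuantizationArgs; only the three
    # fields the gate reads
    num_bits: int
    type: str
    group_size: int

def should_generatre_mxfp4_scales(args: Args) -> bool:
    # Predicate body copied from mxfp4_utils.py in this commit:
    # MXFP4 means FP4 values quantized in groups of 32
    return args.num_bits == 4 and args.type == "float" and args.group_size == 32

assert should_generatre_mxfp4_scales(Args(num_bits=4, type="float", group_size=32))
assert not should_generatre_mxfp4_scales(Args(num_bits=8, type="float", group_size=32))

Any other combination of num_bits, type, and group_size falls through to the unchanged symmetric computation, scales = max_val_pos / (float(bit_range) / 2).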

src/compressed_tensors/quantization/utils/mxfp4_utils.py (14 additions, 37 deletions)

@@ -21,17 +21,21 @@
 
 
 __all__ = [
-    "maybe_convert_from_mxfp4_scale",
+    "maybe_convert_from_mxfp4_exp",
     "generate_mxfp4_scales",
     "round_to_power_2",
-    "maybe_convert_to_mxfp4_scales",
+    "should_generatre_mxfp4_scales",
 ]
 
 # Reference: https://github.com/vllm-project/vllm/blob/main/tests/quantization/reference_mxfp4.py  # noqa: E501
 
 
-def maybe_convert_from_mxfp4_scale(
-    args: QuantizationArgs, scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16
+def should_generatre_mxfp4_scales(args: QuantizationArgs):
+    return args.num_bits == 4 and args.type == "float" and args.group_size == 32
+
+
+def maybe_convert_from_mxfp4_exp(
+    args: QuantizationArgs, scale: torch.Tensor
 ) -> torch.Tensor:
     """
     Converts mxfp4 scales. Scales are powers of 2, with the
@@ -41,30 +45,14 @@ def maybe_convert_from_mxfp4_scale(
     :param scale: uint8 exponent scale
     :param dtype: dense dtype
     """
-    is_mxfp4 = args.num_bits == 4 and args.type == "float" and args.group_size == 32
-    if is_mxfp4:
+    original_dtype = scale.dtype
+    if should_generatre_mxfp4_scales(args):
         scale_exp = scale.to(torch.int32) - 127
         scale = 2.00 ** (scale_exp.to(torch.float))
-        return scale.to(dtype)
+        return scale.to(original_dtype)
     return scale
 
 
-def maybe_convert_to_mxfp4_scales(
-    args: QuantizationArgs, scales: torch.Tensor
-) -> torch.Tensor:
-    """
-    Conver the scales to be mxfp4 compatible scales, if quant args are FP4 with group_size 32.
-    If not, return original scales
-
-    :param args: quantization args
-    :param scales: scales to update
-    """
-    is_mxfp4 = args.num_bits == 4 and args.type == "float" and args.group_size == 32
-    if is_mxfp4:
-        return generate_mxfp4_scales(x=scales)
-    return scales
-
-
 def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
     """
     Round values to the closest power of 2.
@@ -99,28 +87,17 @@ def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
     return block_max_uint.to(torch.uint16).view(torch.bfloat16)
 
 
-def generate_mxfp4_scales(x: torch.Tensor, clamp: bool = False) -> torch.Tensor:
+def generate_mxfp4_scales(x: torch.Tensor) -> torch.Tensor:
     """
     Generate mxfp4 scales. The scales require the following steps
     1. Round to the closest power of 2
     2. Convert to exponent
-    3. Optionally, store in uint8
 
     Called when calculating qparams using observers.
 
     :param x: tensor to round to closest power of 2
-    :returns uint8 scales as exponents
+    :returns scales as exponents
     """
     # Round to closest power of 2
     scale_power_2 = round_to_power_2(x)
-    # Convert to exponent
-    scale_exp = 127 + torch.floor(torch.log2(scale_power_2)).to(torch.int32) - 2
-    # Clamp and store in uint8, as expected by mxfp4
-    if clamp:
-        scale_exp = torch.clamp(
-            scale_exp,
-            max=torch.iinfo(torch.uint8).max,
-            min=torch.iinfo(torch.uint8).min,
-        )
-        return scale_exp.to(torch.uint8)
-    return scale_exp
+    return 127 + torch.floor(torch.log2(scale_power_2)) - 2
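
To see the two helpers above working together, here is a hedged, self-contained roundtrip: generate the biased exponent, then decode it back to a power-of-2 scale. round_to_power_2 is simplified to floor(log2) rather than the library's closest-power-of-2 bfloat16 bit trick, and convert_from_mxfp4_exp drops the args gate and dtype bookkeeping, so treat this as an illustration of the exponent math, not the library API:

import torch

def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in: snap each value down to a power of 2
    # (the library helper rounds to the *closest* power of 2)
    return torch.exp2(torch.floor(torch.log2(x)))

def generate_mxfp4_scales(x: torch.Tensor) -> torch.Tensor:
    # As in this commit: bias of 127, minus 2 (presumably because the
    # FP4 E2M1 max magnitude, 6, lives in the 2**2 binade)
    scale_power_2 = round_to_power_2(x)
    return 127 + torch.floor(torch.log2(scale_power_2)) - 2

def convert_from_mxfp4_exp(scale: torch.Tensor) -> torch.Tensor:
    # The inverse step inside maybe_convert_from_mxfp4_exp
    return 2.0 ** (scale.to(torch.int32) - 127).to(torch.float)

block_max = torch.tensor([0.5, 3.0, 24.0])  # per-group absolute maxima
exps = generate_mxfp4_scales(block_max)     # tensor([124., 126., 129.])
print(convert_from_mxfp4_exp(exps))         # tensor([0.1250, 0.5000, 4.0000])

Note that generate_mxfp4_scales now returns raw float exponents; the uint8 clamp-and-store that used to live behind the removed clamp flag is handled downstream by round_to_quantized_type_dtype, as the updated test below shows.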

tests/test_quantization/test_utils/test_mxfp4_utils.py (7 additions, 3 deletions)

@@ -13,9 +13,10 @@
 # limitations under the License.
 
 import torch
+from compressed_tensors.quantization import round_to_quantized_type_dtype
 from compressed_tensors.quantization.utils import (
     generate_mxfp4_scales,
-    maybe_convert_from_mxfp4_scale,
+    maybe_convert_from_mxfp4_exp,
     round_to_power_2,
 )
 
@@ -77,7 +78,6 @@ def test_mxfp4_scales_e2e():
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
     block_max = torch.max(torch.abs(min_vals), torch.abs(max_vals))
 
-    scales_generated = generate_mxfp4_scales(block_max, clamp=True)
     args = QuantizationArgs(
         num_bits=4,
         type=QuantizationType.FLOAT,
@@ -86,7 +86,11 @@ def test_mxfp4_scales_e2e():
         scale_dtype=torch.uint8,
         zp_dtype=torch.uint8,
     )
-    converted_ct = maybe_convert_from_mxfp4_scale(args=args, scale=scales_generated)
+
+    scales = generate_mxfp4_scales(block_max)
+    scales = round_to_quantized_type_dtype(scales, dtype=args.scale_dtype)
+
+    converted_ct = maybe_convert_from_mxfp4_exp(args=args, scale=scales)
 
     scales_exp = torch.log2(converted_ct)
     block_max_exp = torch.floor(torch.log2(round_to_power_2(block_max))) - 2
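
The hunk ends before the test's final check, but given the two exponent tensors computed above, the assertion is presumably an elementwise comparison along these lines (hypothetical; the actual assert sits outside the diff context):

assert torch.allclose(scales_exp, block_max_exp)  # hypothetical reconstruction

This holds because log2 of the decoded scale recovers floor(log2(round_to_power_2(block_max))) - 2 exactly.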
