__all__ = [
-    "maybe_convert_from_mxfp4_scale",
+    "maybe_convert_from_mxfp4_exp",
    "generate_mxfp4_scales",
    "round_to_power_2",
27+ "should_generatre_mxfp4_scales " ,
]


# Reference: https://github.com/vllm-project/vllm/blob/main/tests/quantization/reference_mxfp4.py # noqa: E501


-def maybe_convert_from_mxfp4_scale(
-    args: QuantizationArgs, scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16
+def should_generate_mxfp4_scales(args: QuantizationArgs):
+    return args.num_bits == 4 and args.type == "float" and args.group_size == 32
+
+
+def maybe_convert_from_mxfp4_exp(
+    args: QuantizationArgs, scale: torch.Tensor
) -> torch.Tensor:
    """
    Converts mxfp4 scales. Scales are powers of 2, with the
@@ -41,30 +45,14 @@ def maybe_convert_from_mxfp4_scale(
    :param scale: uint8 exponent scale
-    :param dtype: dense dtype
4347 """
-    is_mxfp4 = args.num_bits == 4 and args.type == "float" and args.group_size == 32
-    if is_mxfp4:
+    original_dtype = scale.dtype
+    if should_generate_mxfp4_scales(args):
        scale_exp = scale.to(torch.int32) - 127
        scale = 2.00 ** (scale_exp.to(torch.float))
-        return scale.to(dtype)
+        return scale.to(original_dtype)
    return scale


-def maybe_convert_to_mxfp4_scales(
-    args: QuantizationArgs, scales: torch.Tensor
-) -> torch.Tensor:
55- """
56- Conver the scales to be mxfp4 compatible scales, if quant args are FP4 with group_size 32.
57- If not, return original scales
58-
59- :param args: quantization args
60- :param scales: scales to update
61- """
62- is_mxfp4 = args .num_bits == 4 and args .type == "float" and args .group_size == 32
63- if is_mxfp4 :
64- return generate_mxfp4_scales (x = scales )
65- return scales
66-
67-
def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
    """
    Round values to the closest power of 2.
@@ -99,28 +87,17 @@ def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
    return block_max_uint.to(torch.uint16).view(torch.bfloat16)


-def generate_mxfp4_scales(x: torch.Tensor, clamp: bool = False) -> torch.Tensor:
+def generate_mxfp4_scales(x: torch.Tensor) -> torch.Tensor:
    """
    Generate mxfp4 scales. The scales require the following steps:
    1. Round to the closest power of 2
    2. Convert to exponent
-    3. Optionally, store in uint8

    Called when calculating qparams using observers.

    :param x: tensor to round to closest power of 2
-    :returns uint8 scales as exponents
+    :returns scales as exponents
    """
    # Round to closest power of 2
    scale_power_2 = round_to_power_2(x)
-    # Convert to exponent
-    scale_exp = 127 + torch.floor(torch.log2(scale_power_2)).to(torch.int32) - 2
-    # Clamp and store in uint8, as expected by mxfp4
-    if clamp:
-        scale_exp = torch.clamp(
-            scale_exp,
-            max=torch.iinfo(torch.uint8).max,
-            min=torch.iinfo(torch.uint8).min,
-        )
-        return scale_exp.to(torch.uint8)
-    return scale_exp
+    return 127 + torch.floor(torch.log2(scale_power_2)) - 2
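
A minimal usage sketch of the updated helpers, assuming they are importable from the module edited above and that `QuantizationArgs` accepts `num_bits`, `type`, and `group_size` keywords (its constructor is not shown in this diff); the tensor values are illustrative only:

import torch

# Hypothetical mxfp4 quantization config (constructor keywords are assumed).
args = QuantizationArgs(num_bits=4, type="float", group_size=32)

# Per-group maxima produced by an observer, one value per group of 32 elements.
block_max = torch.tensor([0.5, 4.0, 32.0], dtype=torch.bfloat16)

if should_generate_mxfp4_scales(args):
    # Biased exponents: e.g. a rounded scale of 0.5 gives 127 + (-1) - 2 = 124
    exp_scales = generate_mxfp4_scales(block_max)
    # Decode the stored exponents back into power-of-2 scales: 2 ** (exp - 127)
    dense_scales = maybe_convert_from_mxfp4_exp(args, exp_scales)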