@@ -30,6 +30,47 @@ def ref_dynamic_per_tensor_fp8_quant(x, fp8_dtype=torch.float8_e5m2):
         fp8_traits_min, fp8_traits_max).to(fp8_dtype)
     return ref_out, ref_scale.view((1, ))
 
+def ref_dynamic_per_token_quant(x: torch.Tensor,
+                                quant_dtype: torch.dtype,
+                                scale_ub: Optional[torch.Tensor] = None) \
+        -> tuple[torch.Tensor, torch.Tensor]:
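+    # Reference path: returns the quantized tensor and one float32 scale per
+    # token, shaped (num_tokens, 1).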
+
+    assert quant_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]
+    # if scale_ub is not None:
+    #     assert quant_dtype == FP8_DTYPE
+
+    qtype_traits = torch.finfo(quant_dtype)
+    qtype_traits_max = qtype_traits.max
+    qtype_traits_min = qtype_traits.min
+    qtype_max = as_float32_tensor(qtype_traits_max)
+    s_1 = as_float32_tensor(1.0)
+    s_512 = as_float32_tensor(512.0)
+
+    # For fp8, in order to match the cuda kernel output, we have to do exactly
+    # the same operations as in the corresponding fp8 kernel to prevent
+    # rounding errors.
+
+    # Compute scales
+    x_token_max, _ = x.abs().max(dim=-1)
+    x_token_max = as_float32_tensor(x_token_max)
+    if scale_ub is not None:
+        x_token_max = x_token_max.clamp(max=scale_ub)
+    scales = (x_token_max / qtype_max)[:, None]
+
+    # Quant
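+    # Clamp the scales to a small positive floor, 1 / (fp8_max * 512), so an
+    # all-zero token still gets a non-zero scale rather than dividing by zero
+    # below; this mirrors the kernel's minimum scaling factor.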
+    min_scaling_factor = s_1 / (qtype_max * s_512)
+    scales = scales.clamp(min=min_scaling_factor)
+    torch_out = as_float32_tensor(x) / scales
+    torch_out = torch_out.clamp(qtype_traits_min,
+                                qtype_traits_max).to(quant_dtype)
+
+    return torch_out, scales
+
 def seed_everything(seed):
     if seed is not None:
         random.seed(seed)
@@ -68,6 +109,39 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
                                ops_out.to(dtype=torch.float32))
 
 
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("scale_ub", SCALE_UBS)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("fp8_dtype", FP8_DTYPES)
+@torch.inference_mode()
+def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
+                                     dtype: torch.dtype, scale_ub: bool,
+                                     seed: int, fp8_dtype: torch.dtype) -> None:
+    seed_everything(seed)
+
+    x = torch.rand(num_tokens, hidden_size, dtype=dtype,
+                   device="xpu") + 1e-6  # avoid nans
+
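+    # When the scale_ub parametrization is on, use the mean of x as the upper
+    # bound; it falls below at least some per-token maxima, so the clamping
+    # path gets exercised.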
+    scale_ub = torch.mean(x).to(dtype=torch.float32, device='xpu') \
+        if scale_ub else None
+    ref_out, ref_scales = ref_dynamic_per_token_quant(x, fp8_dtype, scale_ub)
+
+    ops_out, ops_scales = scaled_fp8_quant(x,
+                                           scale_ub=scale_ub,
+                                           use_per_token_if_dynamic=True)
+
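+    # Compare in float32; the fp8 outputs are upcast so assert_close can run
+    # an elementwise comparison.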
+    torch.testing.assert_close(ref_scales, ops_scales)
+    torch.testing.assert_close(ref_out.to(dtype=torch.float32),
+                               ops_out.to(dtype=torch.float32))
+
+
 # Regression test for a case with large activations where an int32 index cannot
 # represent the number of elements.
 @torch.inference_mode()