@@ -66,6 +66,32 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
     return torch_out, scales


+def assert_close_percentage(a: torch.Tensor, b: torch.Tensor, mismatch_threshold: float = 0.01):
+    """
+    Assert that two tensors are close within a mismatch percentage.
+
+    Args:
+        a (torch.Tensor): First tensor.
+        b (torch.Tensor): Second tensor.
+        mismatch_threshold (float): Allowed mismatch ratio (0.01 = 1% mismatch allowed).
+
+    Raises:
+        AssertionError: If mismatch percentage exceeds the threshold.
+    """
+    if a.shape != b.shape:
+        raise AssertionError(f"Shape mismatch: {a.shape} vs {b.shape}")
+
+    mismatch_mask = a != b
+    mismatch_count = mismatch_mask.sum().item()
+    total_count = a.numel()
+    mismatch_ratio = mismatch_count / total_count
+
+    if mismatch_ratio > mismatch_threshold:
+        raise AssertionError(
+            f"Tensors differ in {mismatch_ratio * 100:.2f}% of elements "
+            f"(allowed {mismatch_threshold * 100:.2f}%)"
+        )
+
 def seed_everything(seed):
     if seed is not None:
         random.seed(seed)
@@ -79,7 +105,7 @@ def seed_everything(seed):
 NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
 SCALE_UBS = [True, False]
 SEEDS = [0]
-FP8_DTYPES = [torch.float8_e5m2]
+FP8_DTYPES = [torch.float8_e5m2, torch.float8_e4m3fn]


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -97,7 +123,7 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,

     ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x, fp8_dtype)

-    ops_out, ops_scale = scaled_fp8_quant(x)
+    ops_out, ops_scale = scaled_fp8_quant(x, fp8_dtype=fp8_dtype)

     torch.testing.assert_close(ref_scale, ops_scale)
     torch.testing.assert_close(ref_out.to(dtype=torch.float32),
@@ -125,11 +151,13 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,

     ops_out, ops_scales = scaled_fp8_quant(x,
                                            scale_ub=scale_ub,
-                                           use_per_token_if_dynamic=True)
+                                           use_per_token_if_dynamic=True,
+                                           fp8_dtype=fp8_dtype)

     torch.testing.assert_close(ref_scales, ops_scales)
-    torch.testing.assert_close(ref_out.to(dtype=torch.float32),
-                               ops_out.to(dtype=torch.float32))
+    assert_close_percentage(ref_out.to(dtype=torch.float32),
+                            ops_out.to(dtype=torch.float32),
+                            mismatch_threshold=0.005)  # 0.5% mismatch allowed


 # Regression test for a case with large activations where an int32 index cannot
@@ -147,7 +175,7 @@ def test_fp8_quant_large(seed: int, fp8_dtype: torch.dtype) -> None:
     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="xpu")
     ref_out, scale = ref_dynamic_per_tensor_fp8_quant(x, fp8_dtype)

-    ops_out, _ = scaled_fp8_quant(x, scale)
+    ops_out, _ = scaled_fp8_quant(x, scale, fp8_dtype=fp8_dtype)

     # Minimize memory footprint in this test by freeing x and upconverting
     # the outputs in place. (torch.allclose does not support fp8)
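
As a standalone sketch (not part of the patch), the new assert_close_percentage helper can be exercised like this; the tensor sizes and values below are made up for illustration:

import torch

def assert_close_percentage(a: torch.Tensor, b: torch.Tensor, mismatch_threshold: float = 0.01):
    # Same logic as the helper added above: count exact element-wise
    # mismatches and fail only if their ratio exceeds the threshold.
    if a.shape != b.shape:
        raise AssertionError(f"Shape mismatch: {a.shape} vs {b.shape}")
    mismatch_ratio = (a != b).sum().item() / a.numel()
    if mismatch_ratio > mismatch_threshold:
        raise AssertionError(
            f"Tensors differ in {mismatch_ratio * 100:.2f}% of elements "
            f"(allowed {mismatch_threshold * 100:.2f}%)")

ref = torch.zeros(1000)
out = ref.clone()
out[:4] = 1.0  # perturb 4 of 1000 elements -> 0.4% mismatch

assert_close_percentage(ref, out, mismatch_threshold=0.005)  # passes: 0.4% <= 0.5%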