20
20
search_float32_into_fe5m2 ,
21
21
)
22
22
from onnx_array_api .ext_test_case import ExtTestCase
23
+ from ml_dtypes import float8_e4m3fn , float8_e5m2
24
+
25
+
26
+ def new_cvt_float32_to_e4m3fn (x ):
27
+ return numpy .array (x , dtype = numpy .float32 ).astype (float8_e4m3fn )
28
+
29
+
30
+ def new_cvt_e4m3fn_to_float32 (x ):
31
+ return numpy .array (x , dtype = float8_e4m3fn ).astype (numpy .float32 )
32
+
33
+
34
+ def new_cvt_float32_to_e5m2 (x ):
35
+ return numpy .array (x , dtype = numpy .float32 ).astype (float8_e5m2 )
36
+
37
+
38
+ def new_cvt_e5m2_to_float32 (x ):
39
+ return numpy .array (x , dtype = float8_e5m2 ).astype (numpy .float32 )
23
40
24
41
25
42
class TestF8 (ExtTestCase ):
@@ -76,6 +93,17 @@ def test_fe4m3fn_to_float32_all(self):
76
93
continue
77
94
self .assertEqual (a , b )
78
95
96
+ def test_fe4m3fn_to_float32_all_ml_types (self ):
97
+ for i in range (0 , 256 ):
98
+ a = fe4m3_to_float32_float (i )
99
+ b = fe4m3_to_float32 (i )
100
+ c = new_cvt_float32_to_e4m3fn (b )
101
+ if numpy .isnan (a ):
102
+ self .assertTrue (numpy .isnan (b ))
103
+ continue
104
+ self .assertEqual (float (a ), float (c ))
105
+ self .assertEqual (a , b )
106
+
79
107
def test_display_float (self ):
80
108
f = 45
81
109
s = display_float32 (f )
@@ -164,6 +192,7 @@ def test_search_float32_into_fe5m2_equal(self):
164
192
):
165
193
b = search_float32_into_fe5m2 (value )
166
194
nf = float32_to_fe5m2 (value )
195
+ cf = new_cvt_float32_to_e5m2 (value )
167
196
if expected in {253 , 254 , 255 , 125 , 126 , 127 }: # nan
168
197
self .assertIn (b , {253 , 254 , 255 , 125 , 126 , 127 })
169
198
self .assertIn (nf , {253 , 254 , 255 , 125 , 126 , 127 })
@@ -173,6 +202,10 @@ def test_search_float32_into_fe5m2_equal(self):
173
202
else :
174
203
self .assertIn (b , (0 , 128 ))
175
204
self .assertIn (nf , (0 , 128 ))
205
+ if numpy .isnan (float (cf )):
206
+ self .assertTrue (numpy .isnan (fe5m2_to_float32 (nf )))
207
+ continue
208
+ self .assertEqual (fe5m2_to_float32 (nf ), float (cf ))
176
209
177
210
def test_search_float32_into_fe4m3fn (self ):
178
211
values = [(fe4m3_to_float32_float (i ), i ) for i in range (0 , 256 )]
@@ -739,6 +772,33 @@ def test_simple_fe4m3(self):
739
772
back = [fe4m3_to_float32 (c , uz = True ) for c in cvt ]
740
773
self .assertEqual (values , back )
741
774
775
+ # ml-dtypes
776
+
777
+ def test_inf_nan_ml_dtypes (self ):
778
+ x = numpy .float32 (numpy .inf )
779
+ g1 = float32_to_fe4m3 (x )
780
+ g2 = float32_to_fe5m2 (x )
781
+ i1 = fe4m3_to_float32 (g1 )
782
+ i2 = fe5m2_to_float32 (g2 )
783
+ self .assertEqual (i1 , 448 )
784
+ self .assertTrue (numpy .isinf (i2 ))
785
+ m1 = new_cvt_float32_to_e4m3fn (x )
786
+ m2 = new_cvt_float32_to_e5m2 (x )
787
+ self .assertTrue (numpy .isnan (m1 )) # different from ONNX choice
788
+ self .assertTrue (numpy .isinf (m2 ))
789
+
790
+ x = numpy .float32 (numpy .nan )
791
+ g1 = float32_to_fe4m3 (x )
792
+ g2 = float32_to_fe5m2 (x )
793
+ i1 = fe4m3_to_float32 (g1 )
794
+ i2 = fe5m2_to_float32 (g2 )
795
+ self .assertTrue (numpy .isnan (i1 ))
796
+ self .assertTrue (numpy .isnan (i2 ))
797
+ m1 = new_cvt_float32_to_e4m3fn (x )
798
+ m2 = new_cvt_float32_to_e5m2 (x )
799
+ self .assertTrue (numpy .isnan (m1 ))
800
+ self .assertTrue (numpy .isnan (m2 ))
801
+
742
802
743
803
if __name__ == "__main__" :
744
804
TestF8 ().test_search_float32_into_fe4m3fn_simple ()
0 commit comments