Skip to content

Commit 97cc690

Browse files
committed
check with ml_dtypes
" "
1 parent 448a0f2 commit 97cc690

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

_unittests/ut_validation/test_f8.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,23 @@
2020
search_float32_into_fe5m2,
2121
)
2222
from onnx_array_api.ext_test_case import ExtTestCase
23+
from ml_dtypes import float8_e4m3fn, float8_e5m2
24+
25+
26+
def new_cvt_float32_to_e4m3fn(x):
27+
return numpy.array(x, dtype=numpy.float32).astype(float8_e4m3fn)
28+
29+
30+
def new_cvt_e4m3fn_to_float32(x):
31+
return numpy.array(x, dtype=float8_e4m3fn).astype(numpy.float32)
32+
33+
34+
def new_cvt_float32_to_e5m2(x):
35+
return numpy.array(x, dtype=numpy.float32).astype(float8_e5m2)
36+
37+
38+
def new_cvt_e5m2_to_float32(x):
39+
return numpy.array(x, dtype=float8_e5m2).astype(numpy.float32)
2340

2441

2542
class TestF8(ExtTestCase):
@@ -76,6 +93,17 @@ def test_fe4m3fn_to_float32_all(self):
7693
continue
7794
self.assertEqual(a, b)
7895

96+
def test_fe4m3fn_to_float32_all_ml_types(self):
97+
for i in range(0, 256):
98+
a = fe4m3_to_float32_float(i)
99+
b = fe4m3_to_float32(i)
100+
c = new_cvt_float32_to_e4m3fn(b)
101+
if numpy.isnan(a):
102+
self.assertTrue(numpy.isnan(b))
103+
continue
104+
self.assertEqual(float(a), float(c))
105+
self.assertEqual(a, b)
106+
79107
def test_display_float(self):
80108
f = 45
81109
s = display_float32(f)
@@ -164,6 +192,7 @@ def test_search_float32_into_fe5m2_equal(self):
164192
):
165193
b = search_float32_into_fe5m2(value)
166194
nf = float32_to_fe5m2(value)
195+
cf = new_cvt_float32_to_e5m2(value)
167196
if expected in {253, 254, 255, 125, 126, 127}: # nan
168197
self.assertIn(b, {253, 254, 255, 125, 126, 127})
169198
self.assertIn(nf, {253, 254, 255, 125, 126, 127})
@@ -173,6 +202,10 @@ def test_search_float32_into_fe5m2_equal(self):
173202
else:
174203
self.assertIn(b, (0, 128))
175204
self.assertIn(nf, (0, 128))
205+
if numpy.isnan(float(cf)):
206+
self.assertTrue(numpy.isnan(fe5m2_to_float32(nf)))
207+
continue
208+
self.assertEqual(fe5m2_to_float32(nf), float(cf))
176209

177210
def test_search_float32_into_fe4m3fn(self):
178211
values = [(fe4m3_to_float32_float(i), i) for i in range(0, 256)]
@@ -739,6 +772,33 @@ def test_simple_fe4m3(self):
739772
back = [fe4m3_to_float32(c, uz=True) for c in cvt]
740773
self.assertEqual(values, back)
741774

775+
# ml-dtypes
776+
777+
def test_inf_nan_ml_dtypes(self):
778+
x = numpy.float32(numpy.inf)
779+
g1 = float32_to_fe4m3(x)
780+
g2 = float32_to_fe5m2(x)
781+
i1 = fe4m3_to_float32(g1)
782+
i2 = fe5m2_to_float32(g2)
783+
self.assertEqual(i1, 448)
784+
self.assertTrue(numpy.isinf(i2))
785+
m1 = new_cvt_float32_to_e4m3fn(x)
786+
m2 = new_cvt_float32_to_e5m2(x)
787+
self.assertTrue(numpy.isnan(m1)) # different from ONNX choice
788+
self.assertTrue(numpy.isinf(m2))
789+
790+
x = numpy.float32(numpy.nan)
791+
g1 = float32_to_fe4m3(x)
792+
g2 = float32_to_fe5m2(x)
793+
i1 = fe4m3_to_float32(g1)
794+
i2 = fe5m2_to_float32(g2)
795+
self.assertTrue(numpy.isnan(i1))
796+
self.assertTrue(numpy.isnan(i2))
797+
m1 = new_cvt_float32_to_e4m3fn(x)
798+
m2 = new_cvt_float32_to_e5m2(x)
799+
self.assertTrue(numpy.isnan(m1))
800+
self.assertTrue(numpy.isnan(m2))
801+
742802

743803
if __name__ == "__main__":
744804
TestF8().test_search_float32_into_fe4m3fn_simple()

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ furo
55
isort
66
joblib
77
matplotlib
8+
ml-dtypes
89
onnxruntime
910
pandas
1011
pyquickhelper

0 commit comments

Comments
 (0)