Skip to content

Commit a6a735a

Browse files
committed
fix inf to nan
1 parent 97cc690 commit a6a735a

File tree

2 files changed

+27
-45
lines changed

2 files changed

+27
-45
lines changed

_unittests/ut_validation/test_f8.py

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,9 @@ def test_inf_nan(self):
342342
0.203125,
343343
0.75,
344344
numpy.nan,
345-
max(CastFloat8.values_e4m3fn)[0],
346-
max(CastFloat8.values_e4m3fn)[0],
347-
min(CastFloat8.values_e4m3fn)[0],
345+
numpy.nan,
346+
numpy.nan,
347+
-numpy.nan,
348348
],
349349
dtype=numpy.float32,
350350
)
@@ -416,7 +416,7 @@ def test_search_e5m2_pow(self):
416416
)
417417

418418
def test_float32_to_fe4m3fn_inf(self):
419-
mx = max(CastFloat8.values_e4m3fn)[0]
419+
mx = numpy.float32(numpy.nan)
420420
v0 = numpy.float32(mx)
421421
v1 = numpy.float32(numpy.inf)
422422
a = search_float32_into_fe4m3(v0)
@@ -429,7 +429,7 @@ def test_float32_to_fe4m3fn_inf(self):
429429
b = float32_to_fe4m3(v1)
430430
self.assertEqual(a, b)
431431

432-
mi = min(CastFloat8.values_e4m3fn)[0]
432+
mi = numpy.float32(-numpy.nan)
433433
v0 = numpy.float32(mi)
434434
v1 = numpy.float32(-numpy.inf)
435435
a = search_float32_into_fe4m3(v0)
@@ -654,30 +654,18 @@ def test_search_float32_into_fe5m2fnuz(self):
654654
)
655655

656656
def test_float32_to_fe4m3fnuz_inf(self):
657-
v0 = numpy.float32(448)
657+
v0 = numpy.float32(numpy.nan)
658658
v1 = numpy.float32(numpy.inf)
659659
a = search_float32_into_fe4m3(v0, uz=True)
660660
b = search_float32_into_fe4m3(v1, uz=True)
661661
self.assertEqual(a, b)
662662

663-
v0 = numpy.float32(448)
664-
v1 = numpy.float32(numpy.inf)
665-
a = float32_to_fe4m3(v0, uz=True)
666-
b = float32_to_fe4m3(v1, uz=True)
667-
self.assertEqual(a, b)
668-
669-
v0 = numpy.float32(-448)
663+
v0 = numpy.float32(-numpy.nan)
670664
v1 = numpy.float32(-numpy.inf)
671665
a = search_float32_into_fe4m3(v0, uz=True)
672666
b = search_float32_into_fe4m3(v1, uz=True)
673667
self.assertEqual(a, b)
674668

675-
v0 = numpy.float32(-448)
676-
v1 = numpy.float32(-numpy.inf)
677-
a = float32_to_fe4m3(v0, uz=True)
678-
b = float32_to_fe4m3(v1, uz=True)
679-
self.assertEqual(a, b)
680-
681669
v0 = numpy.float32(numpy.nan)
682670
v1 = numpy.float32(-numpy.nan)
683671
a = search_float32_into_fe4m3(v0, uz=True)
@@ -688,7 +676,7 @@ def test_float32_to_fe4m3fnuz_inf(self):
688676
v1 = numpy.float32(-numpy.inf)
689677
a = search_float32_into_fe4m3(v0, uz=True)
690678
b = search_float32_into_fe4m3(v1, uz=True)
691-
self.assertNotEqual(a, b)
679+
self.assertEqual(a, b)
692680

693681
v0 = numpy.float32(numpy.nan)
694682
v1 = numpy.float32(-numpy.nan)
@@ -700,10 +688,10 @@ def test_float32_to_fe4m3fnuz_inf(self):
700688
v1 = numpy.float32(-numpy.inf)
701689
a = float32_to_fe4m3(v0, uz=True)
702690
b = float32_to_fe4m3(v1, uz=True)
703-
self.assertNotEqual(a, b)
691+
self.assertEqual(a, b)
704692

705693
def test_float32_to_fe5m2fnuz_inf(self):
706-
mx = max(CastFloat8.values_e5m2fnuz)[0]
694+
mx = numpy.nan
707695
v0 = numpy.float32(mx)
708696
v1 = numpy.float32(numpy.inf)
709697
a = search_float32_into_fe5m2(v0, fn=True, uz=True)
@@ -716,7 +704,7 @@ def test_float32_to_fe5m2fnuz_inf(self):
716704
b = float32_to_fe5m2(v1, fn=True, uz=True)
717705
self.assertEqual(a, b)
718706

719-
mi = min(CastFloat8.values_e5m2fnuz)[0]
707+
mi = numpy.nan
720708
v0 = numpy.float32(mi)
721709
v1 = numpy.float32(-numpy.inf)
722710
a = search_float32_into_fe5m2(v0, fn=True, uz=True)
@@ -739,7 +727,7 @@ def test_float32_to_fe5m2fnuz_inf(self):
739727
v1 = numpy.float32(-numpy.inf)
740728
a = search_float32_into_fe5m2(v0, fn=True, uz=True)
741729
b = search_float32_into_fe5m2(v1, fn=True, uz=True)
742-
self.assertNotEqual(a, b)
730+
self.assertEqual(a, b)
743731

744732
v0 = numpy.float32(numpy.nan)
745733
v1 = numpy.float32(-numpy.nan)
@@ -751,7 +739,7 @@ def test_float32_to_fe5m2fnuz_inf(self):
751739
v1 = numpy.float32(-numpy.inf)
752740
a = float32_to_fe5m2(v0, fn=True, uz=True)
753741
b = float32_to_fe5m2(v1, fn=True, uz=True)
754-
self.assertNotEqual(a, b)
742+
self.assertEqual(a, b)
755743

756744
def test_simple_fe4m3(self):
757745
values = [448]
@@ -780,7 +768,7 @@ def test_inf_nan_ml_dtypes(self):
780768
g2 = float32_to_fe5m2(x)
781769
i1 = fe4m3_to_float32(g1)
782770
i2 = fe5m2_to_float32(g2)
783-
self.assertEqual(i1, 448)
771+
self.assertNotEqual(i1, 448)
784772
self.assertTrue(numpy.isinf(i2))
785773
m1 = new_cvt_float32_to_e4m3fn(x)
786774
m2 = new_cvt_float32_to_e5m2(x)

onnx_array_api/validation/f8.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import struct
22
import numpy
33

4+
# display functions
45

56
def display_float32(value, sign=1, exponent=8, mantissa=23):
67
"""
@@ -90,6 +91,9 @@ def display_fe5m2(value, sign=1, exponent=4, mantissa=3):
9091
return display_fexmx(value, sign=1, exponent=5, mantissa=2)
9192

9293

94+
# cast from float 8 to float 32
95+
96+
9397
def fe4m3_to_float32_float(ival: int, fn: bool = True, uz: bool = False) -> float:
9498
"""
9599
Casts a float 8 encoded as an integer into a float.
@@ -243,7 +247,6 @@ def fe4m3_to_float32(ival: int, fn: bool = True, uz: bool = False) -> float:
243247
f = numpy.uint32(res).view(numpy.float32) # pylint: disable=E1121
244248
return f
245249

246-
247250
def fe5m2_to_float32(ival: int, fn: bool = False, uz: bool = False) -> float:
248251
"""
249252
Casts a float E5M2 encoded as an integer into a float.
@@ -292,6 +295,7 @@ def fe5m2_to_float32(ival: int, fn: bool = False, uz: bool = False) -> float:
292295
f = numpy.uint32(res).view(numpy.float32) # pylint: disable=E1121
293296
return f
294297

298+
# cast from float32 to float 8
295299

296300
class CastFloat8:
297301
"""
@@ -378,16 +382,12 @@ def search_float32_into_fe4m3(value: float, fn: bool = True, uz: bool = False) -
378382
b = int.from_bytes(struct.pack("<f", numpy.float32(value)), "little")
379383
ret = (b & 0x80000000) >> 24 # sign
380384
if uz:
381-
if numpy.isnan(value):
385+
if numpy.isnan(value) or numpy.isinf(value):
382386
return 0x80
383-
if numpy.isinf(value):
384-
return ret | 0x7F
385387
set_values = CastFloat8.values_e4m3fnuz
386388
else:
387-
if numpy.isnan(value):
389+
if numpy.isnan(value) or numpy.isinf(value):
388390
return 0x7F | ret
389-
if numpy.isinf(value):
390-
return 0x7E | ret
391391
set_values = CastFloat8.values_e4m3fn
392392
f = numpy.float32(value)
393393
i = CastFloat8.find_closest_value(f, set_values)
@@ -407,10 +407,8 @@ def search_float32_into_fe5m2(value: float, fn: bool = False, uz: bool = False)
407407
ret = (b & 0x80000000) >> 24 # sign
408408

409409
if fn and uz:
410-
if numpy.isnan(value):
410+
if numpy.isnan(value) or numpy.isinf(value):
411411
return 0x80
412-
if numpy.isinf(value):
413-
return ret | 0x7F
414412
set_values = CastFloat8.values_e5m2fnuz
415413
elif not fn and not uz:
416414
if numpy.isnan(value):
@@ -438,10 +436,8 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False):
438436
b = int.from_bytes(struct.pack("<f", numpy.float32(x)), "little")
439437
ret = (b & 0x80000000) >> 24 # sign
440438
if uz:
441-
if (b & 0x7FC00000) == 0x7FC00000:
439+
if (b & 0x7FC00000) == 0x7FC00000 or numpy.isinf(x):
442440
return 0x80
443-
if numpy.isinf(x):
444-
return ret | 0x7F # saturation
445441
e = (b & 0x7F800000) >> 23 # exponent
446442
m = b & 0x007FFFFF # mantissa
447443

@@ -475,10 +471,8 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False):
475471
ret |= 0x7F # 01111110
476472
return int(ret)
477473
else:
478-
if (b & 0x7FC00000) == 0x7FC00000:
474+
if (b & 0x7FC00000) == 0x7FC00000 or numpy.isinf(x):
479475
return 0x7F | ret
480-
if numpy.isinf(x):
481-
return 0x7E | ret # saturation
482476
e = (b & 0x7F800000) >> 23 # exponent
483477
m = b & 0x007FFFFF # mantissa
484478

@@ -528,10 +522,10 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False):
528522
ret = (b & 0x80000000) >> 24 # sign
529523

530524
if fn and uz:
531-
if (b & 0x7FC00000) == 0x7FC00000:
525+
if (b & 0x7FC00000) == 0x7FC00000: # NaN
526+
return 0x80
527+
if (b & 0x7FFFFFFF) == 0x7F800000: # Inf
532528
return 0x80
533-
if (b & 0x7FFFFFFF) == 0x7F800000:
534-
return ret | 0x7F
535529
e = (b & 0x7F800000) >> 23 # exponent
536530
m = b & 0x007FFFFF # mantissa
537531

0 commit comments

Comments
 (0)