Skip to content

Commit 448a0f2

Browse files
committed
Adds validation for float 8 (infinities now saturate to the largest finite value instead of mapping to NaN)
1 parent 50c4d52 commit 448a0f2

File tree

2 files changed

+44
-18
lines changed

2 files changed

+44
-18
lines changed

_unittests/ut_validation/test_f8.py

Lines changed: 31 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -237,7 +237,10 @@ def test_search_float32_into_fe5m2(self):
237237
else:
238238
add = v - value
239239
if len(w) > 0:
240-
raise AssertionError(f"A warning was thrown for v={v}, value={value}, w={w[0]}.")
240+
raise AssertionError(
241+
f"A warning was thrown for v={v}, "
242+
f"value={value}, w={w[0]}."
243+
)
241244
else:
242245
v = value + add
243246
b = search_float32_into_fe5m2(v)
@@ -306,9 +309,9 @@ def test_inf_nan(self):
306309
0.203125,
307310
0.75,
308311
numpy.nan,
309-
numpy.nan,
310-
-numpy.nan,
311-
-numpy.nan,
312+
max(CastFloat8.values_e4m3fn)[0],
313+
max(CastFloat8.values_e4m3fn)[0],
314+
min(CastFloat8.values_e4m3fn)[0],
312315
],
313316
dtype=numpy.float32,
314317
)
@@ -380,26 +383,27 @@ def test_search_e5m2_pow(self):
380383
)
381384

382385
def test_float32_to_fe4m3fn_inf(self):
383-
mx =
384-
v0 = numpy.float32(448)
386+
mx = max(CastFloat8.values_e4m3fn)[0]
387+
v0 = numpy.float32(mx)
385388
v1 = numpy.float32(numpy.inf)
386389
a = search_float32_into_fe4m3(v0)
387390
b = search_float32_into_fe4m3(v1)
388391
self.assertEqual(a, b)
389392

390-
v0 = numpy.float32(448)
393+
v0 = numpy.float32(mx)
391394
v1 = numpy.float32(numpy.inf)
392395
a = float32_to_fe4m3(v0)
393396
b = float32_to_fe4m3(v1)
394397
self.assertEqual(a, b)
395398

396-
v0 = numpy.float32(-448)
399+
mi = min(CastFloat8.values_e4m3fn)[0]
400+
v0 = numpy.float32(mi)
397401
v1 = numpy.float32(-numpy.inf)
398402
a = search_float32_into_fe4m3(v0)
399403
b = search_float32_into_fe4m3(v1)
400404
self.assertEqual(a, b)
401405

402-
v0 = numpy.float32(-448)
406+
v0 = numpy.float32(mi)
403407
v1 = numpy.float32(-numpy.inf)
404408
a = float32_to_fe4m3(v0)
405409
b = float32_to_fe4m3(v1)
@@ -666,18 +670,32 @@ def test_float32_to_fe4m3fnuz_inf(self):
666670
self.assertNotEqual(a, b)
667671

668672
def test_float32_to_fe5m2fnuz_inf(self):
669-
v0 = numpy.float32(65536)
673+
mx = max(CastFloat8.values_e5m2fnuz)[0]
674+
v0 = numpy.float32(mx)
670675
v1 = numpy.float32(numpy.inf)
671676
a = search_float32_into_fe5m2(v0, fn=True, uz=True)
672677
b = search_float32_into_fe5m2(v1, fn=True, uz=True)
673678
self.assertEqual(a, b)
674679

675-
v0 = numpy.float32(65536)
680+
v0 = numpy.float32(mx)
676681
v1 = numpy.float32(numpy.inf)
677682
a = float32_to_fe5m2(v0, fn=True, uz=True)
678683
b = float32_to_fe5m2(v1, fn=True, uz=True)
679684
self.assertEqual(a, b)
680685

686+
mi = min(CastFloat8.values_e5m2fnuz)[0]
687+
v0 = numpy.float32(mi)
688+
v1 = numpy.float32(-numpy.inf)
689+
a = search_float32_into_fe5m2(v0, fn=True, uz=True)
690+
b = search_float32_into_fe5m2(v1, fn=True, uz=True)
691+
self.assertEqual(a, b)
692+
693+
v0 = numpy.float32(mi)
694+
v1 = numpy.float32(-numpy.inf)
695+
a = float32_to_fe5m2(v0, fn=True, uz=True)
696+
b = float32_to_fe5m2(v1, fn=True, uz=True)
697+
self.assertEqual(a, b)
698+
681699
v0 = numpy.float32(numpy.nan)
682700
v1 = numpy.float32(-numpy.nan)
683701
a = search_float32_into_fe5m2(v0, fn=True, uz=True)
@@ -688,7 +706,7 @@ def test_float32_to_fe5m2fnuz_inf(self):
688706
v1 = numpy.float32(-numpy.inf)
689707
a = search_float32_into_fe5m2(v0, fn=True, uz=True)
690708
b = search_float32_into_fe5m2(v1, fn=True, uz=True)
691-
self.assertEqual(a, b)
709+
self.assertNotEqual(a, b)
692710

693711
v0 = numpy.float32(numpy.nan)
694712
v1 = numpy.float32(-numpy.nan)
@@ -700,7 +718,7 @@ def test_float32_to_fe5m2fnuz_inf(self):
700718
v1 = numpy.float32(-numpy.inf)
701719
a = float32_to_fe5m2(v0, fn=True, uz=True)
702720
b = float32_to_fe5m2(v1, fn=True, uz=True)
703-
self.assertEqual(a, b)
721+
self.assertNotEqual(a, b)
704722

705723
def test_simple_fe4m3(self):
706724
values = [448]

onnx_array_api/validation/f8.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -378,12 +378,16 @@ def search_float32_into_fe4m3(value: float, fn: bool = True, uz: bool = False) -
378378
b = int.from_bytes(struct.pack("<f", numpy.float32(value)), "little")
379379
ret = (b & 0x80000000) >> 24 # sign
380380
if uz:
381-
if numpy.isnan(value) or numpy.isinf(value):
381+
if numpy.isnan(value):
382382
return 0x80
383+
if numpy.isinf(value):
384+
return ret | 0x7F
383385
set_values = CastFloat8.values_e4m3fnuz
384386
else:
385-
if numpy.isnan(value) or numpy.isinf(value):
387+
if numpy.isnan(value):
386388
return 0x7F | ret
389+
if numpy.isinf(value):
390+
return 0x7E | ret
387391
set_values = CastFloat8.values_e4m3fn
388392
f = numpy.float32(value)
389393
i = CastFloat8.find_closest_value(f, set_values)
@@ -403,8 +407,10 @@ def search_float32_into_fe5m2(value: float, fn: bool = False, uz: bool = False)
403407
ret = (b & 0x80000000) >> 24 # sign
404408

405409
if fn and uz:
406-
if numpy.isnan(value) or numpy.isinf(value):
410+
if numpy.isnan(value):
407411
return 0x80
412+
if numpy.isinf(value):
413+
return ret | 0x7F
408414
set_values = CastFloat8.values_e5m2fnuz
409415
elif not fn and not uz:
410416
if numpy.isnan(value):
@@ -435,7 +441,7 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False):
435441
if (b & 0x7FC00000) == 0x7FC00000:
436442
return 0x80
437443
if numpy.isinf(x):
438-
return 0x80
444+
return ret | 0x7F # saturation
439445
e = (b & 0x7F800000) >> 23 # exponent
440446
m = b & 0x007FFFFF # mantissa
441447

@@ -472,7 +478,7 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False):
472478
if (b & 0x7FC00000) == 0x7FC00000:
473479
return 0x7F | ret
474480
if numpy.isinf(x):
475-
return 0x7F | ret
481+
return 0x7E | ret # saturation
476482
e = (b & 0x7F800000) >> 23 # exponent
477483
m = b & 0x007FFFFF # mantissa
478484

@@ -524,6 +530,8 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False):
524530
if fn and uz:
525531
if (b & 0x7FC00000) == 0x7FC00000:
526532
return 0x80
533+
if (b & 0x7FFFFFFF) == 0x7F800000:
534+
return ret | 0x7F
527535
e = (b & 0x7F800000) >> 23 # exponent
528536
m = b & 0x007FFFFF # mantissa
529537

0 commit comments

Comments
 (0)