@@ -237,7 +237,10 @@ def test_search_float32_into_fe5m2(self):
237
237
else :
238
238
add = v - value
239
239
if len (w ) > 0 :
240
- raise AssertionError (f"A warning was thrown for v={ v } , value={ value } , w={ w [0 ]} ." )
240
+ raise AssertionError (
241
+ f"A warning was thrown for v={ v } , "
242
+ f"value={ value } , w={ w [0 ]} ."
243
+ )
241
244
else :
242
245
v = value + add
243
246
b = search_float32_into_fe5m2 (v )
@@ -306,9 +309,9 @@ def test_inf_nan(self):
306
309
0.203125 ,
307
310
0.75 ,
308
311
numpy .nan ,
309
- numpy . nan ,
310
- - numpy . nan ,
311
- - numpy . nan ,
312
+ max ( CastFloat8 . values_e4m3fn )[ 0 ] ,
313
+ max ( CastFloat8 . values_e4m3fn )[ 0 ] ,
314
+ min ( CastFloat8 . values_e4m3fn )[ 0 ] ,
312
315
],
313
316
dtype = numpy .float32 ,
314
317
)
@@ -380,26 +383,27 @@ def test_search_e5m2_pow(self):
380
383
)
381
384
382
385
def test_float32_to_fe4m3fn_inf (self ):
383
- mx =
384
- v0 = numpy .float32 (448 )
386
+ mx = max ( CastFloat8 . values_e4m3fn )[ 0 ]
387
+ v0 = numpy .float32 (mx )
385
388
v1 = numpy .float32 (numpy .inf )
386
389
a = search_float32_into_fe4m3 (v0 )
387
390
b = search_float32_into_fe4m3 (v1 )
388
391
self .assertEqual (a , b )
389
392
390
- v0 = numpy .float32 (448 )
393
+ v0 = numpy .float32 (mx )
391
394
v1 = numpy .float32 (numpy .inf )
392
395
a = float32_to_fe4m3 (v0 )
393
396
b = float32_to_fe4m3 (v1 )
394
397
self .assertEqual (a , b )
395
398
396
- v0 = numpy .float32 (- 448 )
399
+ mi = min (CastFloat8 .values_e4m3fn )[0 ]
400
+ v0 = numpy .float32 (mi )
397
401
v1 = numpy .float32 (- numpy .inf )
398
402
a = search_float32_into_fe4m3 (v0 )
399
403
b = search_float32_into_fe4m3 (v1 )
400
404
self .assertEqual (a , b )
401
405
402
- v0 = numpy .float32 (- 448 )
406
+ v0 = numpy .float32 (mi )
403
407
v1 = numpy .float32 (- numpy .inf )
404
408
a = float32_to_fe4m3 (v0 )
405
409
b = float32_to_fe4m3 (v1 )
@@ -666,18 +670,32 @@ def test_float32_to_fe4m3fnuz_inf(self):
666
670
self .assertNotEqual (a , b )
667
671
668
672
def test_float32_to_fe5m2fnuz_inf (self ):
669
- v0 = numpy .float32 (65536 )
673
+ mx = max (CastFloat8 .values_e5m2fnuz )[0 ]
674
+ v0 = numpy .float32 (mx )
670
675
v1 = numpy .float32 (numpy .inf )
671
676
a = search_float32_into_fe5m2 (v0 , fn = True , uz = True )
672
677
b = search_float32_into_fe5m2 (v1 , fn = True , uz = True )
673
678
self .assertEqual (a , b )
674
679
675
- v0 = numpy .float32 (65536 )
680
+ v0 = numpy .float32 (mx )
676
681
v1 = numpy .float32 (numpy .inf )
677
682
a = float32_to_fe5m2 (v0 , fn = True , uz = True )
678
683
b = float32_to_fe5m2 (v1 , fn = True , uz = True )
679
684
self .assertEqual (a , b )
680
685
686
+ mi = min (CastFloat8 .values_e5m2fnuz )[0 ]
687
+ v0 = numpy .float32 (mi )
688
+ v1 = numpy .float32 (- numpy .inf )
689
+ a = search_float32_into_fe5m2 (v0 , fn = True , uz = True )
690
+ b = search_float32_into_fe5m2 (v1 , fn = True , uz = True )
691
+ self .assertEqual (a , b )
692
+
693
+ v0 = numpy .float32 (mi )
694
+ v1 = numpy .float32 (- numpy .inf )
695
+ a = float32_to_fe5m2 (v0 , fn = True , uz = True )
696
+ b = float32_to_fe5m2 (v1 , fn = True , uz = True )
697
+ self .assertEqual (a , b )
698
+
681
699
v0 = numpy .float32 (numpy .nan )
682
700
v1 = numpy .float32 (- numpy .nan )
683
701
a = search_float32_into_fe5m2 (v0 , fn = True , uz = True )
@@ -688,7 +706,7 @@ def test_float32_to_fe5m2fnuz_inf(self):
688
706
v1 = numpy .float32 (- numpy .inf )
689
707
a = search_float32_into_fe5m2 (v0 , fn = True , uz = True )
690
708
b = search_float32_into_fe5m2 (v1 , fn = True , uz = True )
691
- self .assertEqual (a , b )
709
+ self .assertNotEqual (a , b )
692
710
693
711
v0 = numpy .float32 (numpy .nan )
694
712
v1 = numpy .float32 (- numpy .nan )
@@ -700,7 +718,7 @@ def test_float32_to_fe5m2fnuz_inf(self):
700
718
v1 = numpy .float32 (- numpy .inf )
701
719
a = float32_to_fe5m2 (v0 , fn = True , uz = True )
702
720
b = float32_to_fe5m2 (v1 , fn = True , uz = True )
703
- self .assertEqual (a , b )
721
+ self .assertNotEqual (a , b )
704
722
705
723
def test_simple_fe4m3 (self ):
706
724
values = [448 ]
0 commit comments