@@ -342,9 +342,9 @@ def test_inf_nan(self):
342
342
0.203125 ,
343
343
0.75 ,
344
344
numpy .nan ,
345
- max ( CastFloat8 . values_e4m3fn )[ 0 ] ,
346
- max ( CastFloat8 . values_e4m3fn )[ 0 ] ,
347
- min ( CastFloat8 . values_e4m3fn )[ 0 ] ,
345
+ numpy . nan ,
346
+ numpy . nan ,
347
+ - numpy . nan ,
348
348
],
349
349
dtype = numpy .float32 ,
350
350
)
@@ -416,7 +416,7 @@ def test_search_e5m2_pow(self):
416
416
)
417
417
418
418
def test_float32_to_fe4m3fn_inf (self ):
419
- mx = max ( CastFloat8 . values_e4m3fn )[ 0 ]
419
+ mx = numpy . float32 ( numpy . nan )
420
420
v0 = numpy .float32 (mx )
421
421
v1 = numpy .float32 (numpy .inf )
422
422
a = search_float32_into_fe4m3 (v0 )
@@ -429,7 +429,7 @@ def test_float32_to_fe4m3fn_inf(self):
429
429
b = float32_to_fe4m3 (v1 )
430
430
self .assertEqual (a , b )
431
431
432
- mi = min ( CastFloat8 . values_e4m3fn )[ 0 ]
432
+ mi = numpy . float32 ( - numpy . nan )
433
433
v0 = numpy .float32 (mi )
434
434
v1 = numpy .float32 (- numpy .inf )
435
435
a = search_float32_into_fe4m3 (v0 )
@@ -654,30 +654,18 @@ def test_search_float32_into_fe5m2fnuz(self):
654
654
)
655
655
656
656
def test_float32_to_fe4m3fnuz_inf (self ):
657
- v0 = numpy .float32 (448 )
657
+ v0 = numpy .float32 (numpy . nan )
658
658
v1 = numpy .float32 (numpy .inf )
659
659
a = search_float32_into_fe4m3 (v0 , uz = True )
660
660
b = search_float32_into_fe4m3 (v1 , uz = True )
661
661
self .assertEqual (a , b )
662
662
663
- v0 = numpy .float32 (448 )
664
- v1 = numpy .float32 (numpy .inf )
665
- a = float32_to_fe4m3 (v0 , uz = True )
666
- b = float32_to_fe4m3 (v1 , uz = True )
667
- self .assertEqual (a , b )
668
-
669
- v0 = numpy .float32 (- 448 )
663
+ v0 = numpy .float32 (- numpy .nan )
670
664
v1 = numpy .float32 (- numpy .inf )
671
665
a = search_float32_into_fe4m3 (v0 , uz = True )
672
666
b = search_float32_into_fe4m3 (v1 , uz = True )
673
667
self .assertEqual (a , b )
674
668
675
- v0 = numpy .float32 (- 448 )
676
- v1 = numpy .float32 (- numpy .inf )
677
- a = float32_to_fe4m3 (v0 , uz = True )
678
- b = float32_to_fe4m3 (v1 , uz = True )
679
- self .assertEqual (a , b )
680
-
681
669
v0 = numpy .float32 (numpy .nan )
682
670
v1 = numpy .float32 (- numpy .nan )
683
671
a = search_float32_into_fe4m3 (v0 , uz = True )
@@ -688,7 +676,7 @@ def test_float32_to_fe4m3fnuz_inf(self):
688
676
v1 = numpy .float32 (- numpy .inf )
689
677
a = search_float32_into_fe4m3 (v0 , uz = True )
690
678
b = search_float32_into_fe4m3 (v1 , uz = True )
691
- self .assertNotEqual (a , b )
679
+ self .assertEqual (a , b )
692
680
693
681
v0 = numpy .float32 (numpy .nan )
694
682
v1 = numpy .float32 (- numpy .nan )
@@ -700,10 +688,10 @@ def test_float32_to_fe4m3fnuz_inf(self):
700
688
v1 = numpy .float32 (- numpy .inf )
701
689
a = float32_to_fe4m3 (v0 , uz = True )
702
690
b = float32_to_fe4m3 (v1 , uz = True )
703
- self .assertNotEqual (a , b )
691
+ self .assertEqual (a , b )
704
692
705
693
def test_float32_to_fe5m2fnuz_inf (self ):
706
- mx = max ( CastFloat8 . values_e5m2fnuz )[ 0 ]
694
+ mx = numpy . nan
707
695
v0 = numpy .float32 (mx )
708
696
v1 = numpy .float32 (numpy .inf )
709
697
a = search_float32_into_fe5m2 (v0 , fn = True , uz = True )
@@ -716,7 +704,7 @@ def test_float32_to_fe5m2fnuz_inf(self):
716
704
b = float32_to_fe5m2 (v1 , fn = True , uz = True )
717
705
self .assertEqual (a , b )
718
706
719
- mi = min ( CastFloat8 . values_e5m2fnuz )[ 0 ]
707
+ mi = numpy . nan
720
708
v0 = numpy .float32 (mi )
721
709
v1 = numpy .float32 (- numpy .inf )
722
710
a = search_float32_into_fe5m2 (v0 , fn = True , uz = True )
@@ -739,7 +727,7 @@ def test_float32_to_fe5m2fnuz_inf(self):
739
727
v1 = numpy .float32 (- numpy .inf )
740
728
a = search_float32_into_fe5m2 (v0 , fn = True , uz = True )
741
729
b = search_float32_into_fe5m2 (v1 , fn = True , uz = True )
742
- self .assertNotEqual (a , b )
730
+ self .assertEqual (a , b )
743
731
744
732
v0 = numpy .float32 (numpy .nan )
745
733
v1 = numpy .float32 (- numpy .nan )
@@ -751,7 +739,7 @@ def test_float32_to_fe5m2fnuz_inf(self):
751
739
v1 = numpy .float32 (- numpy .inf )
752
740
a = float32_to_fe5m2 (v0 , fn = True , uz = True )
753
741
b = float32_to_fe5m2 (v1 , fn = True , uz = True )
754
- self .assertNotEqual (a , b )
742
+ self .assertEqual (a , b )
755
743
756
744
def test_simple_fe4m3 (self ):
757
745
values = [448 ]
@@ -780,7 +768,7 @@ def test_inf_nan_ml_dtypes(self):
780
768
g2 = float32_to_fe5m2 (x )
781
769
i1 = fe4m3_to_float32 (g1 )
782
770
i2 = fe5m2_to_float32 (g2 )
783
- self .assertEqual (i1 , 448 )
771
+ self .assertNotEqual (i1 , 448 )
784
772
self .assertTrue (numpy .isinf (i2 ))
785
773
m1 = new_cvt_float32_to_e4m3fn (x )
786
774
m2 = new_cvt_float32_to_e5m2 (x )
0 commit comments