55
55
56
56
ALLOW_NAN = False
57
57
ALLOW_INFINITY = False
58
+ ALLOW_SUBNORMAL = False
58
59
59
60
JAX_MODE = False
60
61
NUMPY_MODE = not JAX_MODE
@@ -81,6 +82,12 @@ def _getattr(obj, name):
81
82
return functools .reduce (getattr , names , obj )
82
83
83
84
85
+ def _maybe_get_subnormal_kwarg (allow_subnormal = ALLOW_SUBNORMAL ):
86
+ if hp .__version_info__ >= (6 , 30 ):
87
+ return {'allow_subnormal' : allow_subnormal }
88
+ return {}
89
+
90
+
84
91
class TestCase (dict ):
85
92
"""`dict` object containing test strategies for a single function."""
86
93
@@ -121,18 +128,21 @@ def floats(draw,
121
128
max_value = 1e16 ,
122
129
allow_nan = ALLOW_NAN ,
123
130
allow_infinity = ALLOW_INFINITY ,
131
+ allow_subnormal = ALLOW_SUBNORMAL ,
124
132
dtype = None ):
125
133
if dtype is None :
126
134
dtype = np .float32 if FLAGS .use_tpu else np .float64
127
135
if min_value is not None :
128
136
min_value = onp .array (min_value , dtype = dtype ).item ()
129
137
if max_value is not None :
130
138
max_value = onp .array (max_value , dtype = dtype ).item ()
139
+ subnormal_kwarg = _maybe_get_subnormal_kwarg (allow_subnormal )
131
140
return draw (hps .floats (min_value = min_value ,
132
141
max_value = max_value ,
133
142
allow_nan = allow_nan ,
134
143
allow_infinity = allow_infinity ,
135
- width = np .dtype (dtype ).itemsize * 8 ))
144
+ width = np .dtype (dtype ).itemsize * 8 ,
145
+ ** subnormal_kwarg ))
136
146
137
147
138
148
def integers (min_value = - 2 ** 30 , max_value = 2 ** 30 ):
@@ -604,11 +614,15 @@ def top_k_params(draw):
604
614
def histogram_fixed_width_bins_params (draw ):
605
615
# TODO(b/187125431): the `min_side=2` and `unique` check can be removed if
606
616
# https://github.com/tensorflow/tensorflow/pull/38899 is re-implemented.
617
+ subnormal_kwarg = _maybe_get_subnormal_kwarg ()
607
618
values = draw (single_arrays (
608
619
dtype = np .float32 ,
609
620
shape = shapes (min_dims = 1 , min_side = 2 ),
610
621
unique = True ,
611
- elements = hps .floats (min_value = - 1e5 , max_value = 1e5 , width = 32 )
622
+ # Avoid intervals containing 0 due to NP/TF discrepancy for bin boundaries
623
+ # near 0.
624
+ elements = hps .floats (min_value = 0. , max_value = 1e10 , width = 32 ,
625
+ ** subnormal_kwarg ),
612
626
))
613
627
vmin , vmax = np .min (values ), np .max (values )
614
628
value_min = draw (hps .one_of (
@@ -699,10 +713,12 @@ def sparse_xent_params(draw):
699
713
shape = hps .just (tuple ()),
700
714
dtype = np .int32 ,
701
715
elements = hps .integers (0 , num_classes - 1 ))
716
+ subnormal_kwarg = _maybe_get_subnormal_kwarg ()
702
717
logits = single_arrays (
703
718
batch_shape = batch_shape ,
704
719
shape = hps .just ((num_classes ,)),
705
- elements = hps .floats (min_value = - 1e5 , max_value = 1e5 , width = 32 ))
720
+ elements = hps .floats (min_value = - 1e5 , max_value = 1e5 , width = 32 ,
721
+ ** subnormal_kwarg ))
706
722
return draw (
707
723
hps .fixed_dictionaries (dict (
708
724
labels = labels , logits = logits )).map (Kwargs ))
@@ -714,10 +730,12 @@ def xent_params(draw):
714
730
batch_shape = draw (shapes (min_dims = 1 ))
715
731
labels = batched_probabilities (
716
732
batch_shape = batch_shape , num_classes = num_classes )
733
+ subnormal_kwarg = _maybe_get_subnormal_kwarg ()
717
734
logits = single_arrays (
718
735
batch_shape = batch_shape ,
719
736
shape = hps .just ((num_classes ,)),
720
- elements = hps .floats (min_value = - 1e5 , max_value = 1e5 , width = 32 ))
737
+ elements = hps .floats (min_value = - 1e5 , max_value = 1e5 , width = 32 ,
738
+ ** subnormal_kwarg ))
721
739
return draw (
722
740
hps .fixed_dictionaries (dict (
723
741
labels = labels , logits = logits )).map (Kwargs ))
@@ -965,7 +983,9 @@ def _not_implemented(*args, **kwargs):
965
983
# keywords=None,
966
984
# defaults=(False, True, None))
967
985
TestCase (
968
- 'linalg.svd' , [single_arrays (shape = shapes (min_dims = 2 ))],
986
+ 'linalg.svd' , [single_arrays (
987
+ shape = shapes (min_dims = 2 ),
988
+ elements = floats (min_value = - 1e10 , max_value = 1e10 ))],
969
989
post_processor = _svd_post_process ),
970
990
TestCase (
971
991
'linalg.qr' , [
@@ -1177,8 +1197,11 @@ def _not_implemented(*args, **kwargs):
1177
1197
xla_const_args = (1 , 2 , 3 )),
1178
1198
TestCase (
1179
1199
'math.cumsum' , [
1180
- hps .tuples (array_axis_tuples (), hps .booleans (),
1181
- hps .booleans ()).map (lambda x : x [0 ] + (x [1 ], x [2 ]))
1200
+ hps .tuples (
1201
+ array_axis_tuples (
1202
+ elements = floats (min_value = - 1e12 , max_value = 1e12 )),
1203
+ hps .booleans (),
1204
+ hps .booleans ()).map (lambda x : x [0 ] + (x [1 ], x [2 ]))
1182
1205
],
1183
1206
xla_const_args = (1 , 2 , 3 )),
1184
1207
]
@@ -1222,7 +1245,8 @@ def _not_implemented(*args, **kwargs):
1222
1245
TestCase ('math.cos' , [single_arrays ()]),
1223
1246
TestCase ('math.cosh' , [single_arrays (elements = floats (- 100. , 100. ))]),
1224
1247
TestCase ('math.digamma' ,
1225
- [single_arrays (elements = non_zero_floats (- 1e4 , 1e4 ))]),
1248
+ [single_arrays (elements = non_zero_floats (- 1e4 , 1e4 ))],
1249
+ rtol = 5e-5 ),
1226
1250
TestCase ('math.erf' , [single_arrays ()]),
1227
1251
TestCase ('math.erfc' , [single_arrays ()]),
1228
1252
TestCase ('math.erfinv' , [single_arrays (elements = floats (- 1. , 1. ))]),
@@ -1274,7 +1298,10 @@ def _not_implemented(*args, **kwargs):
1274
1298
TestCase ('math.divide_no_nan' , [n_same_shape (n = 2 )]),
1275
1299
TestCase ('math.equal' , [n_same_shape (n = 2 )]),
1276
1300
TestCase ('math.floordiv' ,
1277
- [n_same_shape (n = 2 , elements = [floats (), non_zero_floats ()])]),
1301
+ # Clip numerator above zero to avoid NP/TF discrepancy in rounding
1302
+ # negative subnormal floats.
1303
+ [n_same_shape (
1304
+ n = 2 , elements = [positive_floats (), non_zero_floats ()])]),
1278
1305
TestCase ('math.floormod' ,
1279
1306
[n_same_shape (n = 2 , elements = [floats (), non_zero_floats ()])]),
1280
1307
TestCase ('math.greater' , [n_same_shape (n = 2 )]),
0 commit comments