Skip to content

Commit 68f626f

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Disallow subnormal floats in numpy_test and unpin hypothesis version.
PiperOrigin-RevId: 451291890
1 parent e27dcc5 commit 68f626f

File tree

2 files changed

+37
-11
lines changed

2 files changed

+37
-11
lines changed

tensorflow_probability/python/internal/backend/numpy/numpy_test.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555

5656
ALLOW_NAN = False
5757
ALLOW_INFINITY = False
58+
ALLOW_SUBNORMAL = False
5859

5960
JAX_MODE = False
6061
NUMPY_MODE = not JAX_MODE
@@ -81,6 +82,12 @@ def _getattr(obj, name):
8182
return functools.reduce(getattr, names, obj)
8283

8384

85+
def _maybe_get_subnormal_kwarg(allow_subnormal=ALLOW_SUBNORMAL):
86+
if hp.__version_info__ >= (6, 30):
87+
return {'allow_subnormal': allow_subnormal}
88+
return {}
89+
90+
8491
class TestCase(dict):
8592
"""`dict` object containing test strategies for a single function."""
8693

@@ -121,18 +128,21 @@ def floats(draw,
121128
max_value=1e16,
122129
allow_nan=ALLOW_NAN,
123130
allow_infinity=ALLOW_INFINITY,
131+
allow_subnormal=ALLOW_SUBNORMAL,
124132
dtype=None):
125133
if dtype is None:
126134
dtype = np.float32 if FLAGS.use_tpu else np.float64
127135
if min_value is not None:
128136
min_value = onp.array(min_value, dtype=dtype).item()
129137
if max_value is not None:
130138
max_value = onp.array(max_value, dtype=dtype).item()
139+
subnormal_kwarg = _maybe_get_subnormal_kwarg(allow_subnormal)
131140
return draw(hps.floats(min_value=min_value,
132141
max_value=max_value,
133142
allow_nan=allow_nan,
134143
allow_infinity=allow_infinity,
135-
width=np.dtype(dtype).itemsize * 8))
144+
width=np.dtype(dtype).itemsize * 8,
145+
**subnormal_kwarg))
136146

137147

138148
def integers(min_value=-2**30, max_value=2**30):
@@ -604,11 +614,15 @@ def top_k_params(draw):
604614
def histogram_fixed_width_bins_params(draw):
605615
# TODO(b/187125431): the `min_side=2` and `unique` check can be removed if
606616
# https://github.com/tensorflow/tensorflow/pull/38899 is re-implemented.
617+
subnormal_kwarg = _maybe_get_subnormal_kwarg()
607618
values = draw(single_arrays(
608619
dtype=np.float32,
609620
shape=shapes(min_dims=1, min_side=2),
610621
unique=True,
611-
elements=hps.floats(min_value=-1e5, max_value=1e5, width=32)
622+
# Avoid intervals containing 0 due to NP/TF discrepancy for bin boundaries
623+
# near 0.
624+
elements=hps.floats(min_value=0., max_value=1e10, width=32,
625+
**subnormal_kwarg),
612626
))
613627
vmin, vmax = np.min(values), np.max(values)
614628
value_min = draw(hps.one_of(
@@ -699,10 +713,12 @@ def sparse_xent_params(draw):
699713
shape=hps.just(tuple()),
700714
dtype=np.int32,
701715
elements=hps.integers(0, num_classes - 1))
716+
subnormal_kwarg = _maybe_get_subnormal_kwarg()
702717
logits = single_arrays(
703718
batch_shape=batch_shape,
704719
shape=hps.just((num_classes,)),
705-
elements=hps.floats(min_value=-1e5, max_value=1e5, width=32))
720+
elements=hps.floats(min_value=-1e5, max_value=1e5, width=32,
721+
**subnormal_kwarg))
706722
return draw(
707723
hps.fixed_dictionaries(dict(
708724
labels=labels, logits=logits)).map(Kwargs))
@@ -714,10 +730,12 @@ def xent_params(draw):
714730
batch_shape = draw(shapes(min_dims=1))
715731
labels = batched_probabilities(
716732
batch_shape=batch_shape, num_classes=num_classes)
733+
subnormal_kwarg = _maybe_get_subnormal_kwarg()
717734
logits = single_arrays(
718735
batch_shape=batch_shape,
719736
shape=hps.just((num_classes,)),
720-
elements=hps.floats(min_value=-1e5, max_value=1e5, width=32))
737+
elements=hps.floats(min_value=-1e5, max_value=1e5, width=32,
738+
**subnormal_kwarg))
721739
return draw(
722740
hps.fixed_dictionaries(dict(
723741
labels=labels, logits=logits)).map(Kwargs))
@@ -965,7 +983,9 @@ def _not_implemented(*args, **kwargs):
965983
# keywords=None,
966984
# defaults=(False, True, None))
967985
TestCase(
968-
'linalg.svd', [single_arrays(shape=shapes(min_dims=2))],
986+
'linalg.svd', [single_arrays(
987+
shape=shapes(min_dims=2),
988+
elements=floats(min_value=-1e10, max_value=1e10))],
969989
post_processor=_svd_post_process),
970990
TestCase(
971991
'linalg.qr', [
@@ -1177,8 +1197,11 @@ def _not_implemented(*args, **kwargs):
11771197
xla_const_args=(1, 2, 3)),
11781198
TestCase(
11791199
'math.cumsum', [
1180-
hps.tuples(array_axis_tuples(), hps.booleans(),
1181-
hps.booleans()).map(lambda x: x[0] + (x[1], x[2]))
1200+
hps.tuples(
1201+
array_axis_tuples(
1202+
elements=floats(min_value=-1e12, max_value=1e12)),
1203+
hps.booleans(),
1204+
hps.booleans()).map(lambda x: x[0] + (x[1], x[2]))
11821205
],
11831206
xla_const_args=(1, 2, 3)),
11841207
]
@@ -1222,7 +1245,8 @@ def _not_implemented(*args, **kwargs):
12221245
TestCase('math.cos', [single_arrays()]),
12231246
TestCase('math.cosh', [single_arrays(elements=floats(-100., 100.))]),
12241247
TestCase('math.digamma',
1225-
[single_arrays(elements=non_zero_floats(-1e4, 1e4))]),
1248+
[single_arrays(elements=non_zero_floats(-1e4, 1e4))],
1249+
rtol=5e-5),
12261250
TestCase('math.erf', [single_arrays()]),
12271251
TestCase('math.erfc', [single_arrays()]),
12281252
TestCase('math.erfinv', [single_arrays(elements=floats(-1., 1.))]),
@@ -1274,7 +1298,10 @@ def _not_implemented(*args, **kwargs):
12741298
TestCase('math.divide_no_nan', [n_same_shape(n=2)]),
12751299
TestCase('math.equal', [n_same_shape(n=2)]),
12761300
TestCase('math.floordiv',
1277-
[n_same_shape(n=2, elements=[floats(), non_zero_floats()])]),
1301+
# Clip numerator above zero to avoid NP/TF discrepancy in rounding
1302+
# negative subnormal floats.
1303+
[n_same_shape(
1304+
n=2, elements=[positive_floats(), non_zero_floats()])]),
12781305
TestCase('math.floormod',
12791306
[n_same_shape(n=2, elements=[floats(), non_zero_floats()])]),
12801307
TestCase('math.greater', [n_same_shape(n=2)]),

testing/dependency_install_lib.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ install_common_packages() {
8181
install_test_only_packages() {
8282
# The following unofficial dependencies are used only by tests.
8383
PIP_FLAGS=${1-}
84-
# TODO(b/233927309): Unpin hypothesis version.
85-
python -m pip install $PIP_FLAGS hypothesis==6.46.7 matplotlib mock mpmath scipy pandas optax
84+
python -m pip install $PIP_FLAGS hypothesis matplotlib mock mpmath scipy pandas optax
8685
}
8786

8887
dump_versions() {

0 commit comments

Comments
 (0)