Skip to content

Commit c1728d3

Browse files
csutertensorflower-gardener
authored andcommitted
Avoid deprecated casting of size-1 np.ndarrays.
This used to be allowed but is now deprecated. Some logic that lies downstream of many of our distributions' log_prob methods would invoke this behavior (in a try/except, so it would not fail even post-deprecation, but we get an annoyting warning all the time). This change avoids that deprecated behavior. PiperOrigin-RevId: 574008432
1 parent e6907a1 commit c1728d3

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

tensorflow_probability/python/distributions/internal/statistical_testing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
import tensorflow.compat.v2 as tf
128128
from tensorflow_probability.python.internal import distribution_util
129129
from tensorflow_probability.python.internal import dtype_util
130+
from tensorflow_probability.python.internal import prefer_static as ps
130131
from tensorflow_probability.python.internal import tensorshape_util
131132
from tensorflow_probability.python.util.seed_stream import SeedStream
132133

@@ -1494,7 +1495,7 @@ def _random_unit_hypersphere(sample_shape, event_shape, dtype, seed):
14941495
target_shape = tf.concat([sample_shape, event_shape], axis=0)
14951496
return tf.math.l2_normalize(
14961497
tf.random.normal(target_shape, seed=seed, dtype=dtype),
1497-
axis=-1 - tf.range(tf.size(event_shape)))
1498+
axis=-1 - ps.range(ps.size(event_shape)))
14981499

14991500

15001501
def assert_multivariate_true_cdf_equal_on_projections_two_sample(

tensorflow_probability/python/internal/backend/numpy/numpy_math.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,15 @@
165165

166166
def _astuple(x):
167167
"""Attempt to convert the given argument to be a Python tuple."""
168-
try:
169-
return (int(x),)
170-
except TypeError:
171-
pass
168+
# Numpy used to allow casting a size-1 ndarray to python scalar literal types.
169+
# In version 1.25 this was deprecated, causing a warning to be issued in the
170+
# below try/except. To avoid that, we just fall through in the case of an
171+
# np.ndarray.
172+
if not isinstance(x, np.ndarray):
173+
try:
174+
return (int(x),)
175+
except TypeError:
176+
pass
172177

173178
try:
174179
return tuple(x)

0 commit comments

Comments
 (0)