Skip to content

Commit 4c5c0f9

Browse files
jburnimtensorflower-gardener
authored andcommitted
Under JAX, avoid int64 in argmin/argmax/count_nonzero if x64 is not enabled.
PiperOrigin-RevId: 473045851
1 parent ea6efb6 commit 4c5c0f9

File tree

1 file changed

+13
-6
lines changed
  • tensorflow_probability/python/internal/backend/numpy

1 file changed

+13
-6
lines changed

tensorflow_probability/python/internal/backend/numpy/numpy_math.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,12 @@ def _astuple(x):
185185
return x
186186

187187

188+
def _default_index_type():
189+
if JAX_MODE and not jax.config.read('jax_enable_x64'):
190+
return np.int32
191+
return np.int64
192+
193+
188194
def _bincount(arr, weights=None, minlength=None, maxlength=None, # pylint: disable=unused-argument
189195
dtype=np.int32, name=None): # pylint: disable=unused-argument
190196
"""Counts number of occurences of each value in `arr`."""
@@ -467,16 +473,16 @@ def _unsorted_segment_sum(data, segment_ids, num_segments, name=None):
467473

468474
argmax = utils.copy_docstring(
469475
'tf.math.argmax',
470-
lambda input, axis=None, output_type=np.int64, name=None: ( # pylint: disable=g-long-lambda
476+
lambda input, axis=None, output_type=None, name=None: ( # pylint: disable=g-long-lambda
471477
np.argmax(input, axis=0 if axis is None else int(axis))
472-
.astype(utils.numpy_dtype(output_type))))
478+
.astype(utils.numpy_dtype(output_type or _default_index_type()))))
473479

474480
argmin = utils.copy_docstring(
475481
'tf.math.argmin',
476-
lambda input, axis=None, output_type=np.int64, name=None: ( # pylint: disable=g-long-lambda
482+
lambda input, axis=None, output_type=None, name=None: ( # pylint: disable=g-long-lambda
477483
np.argmin(_convert_to_tensor(
478484
input), axis=0 if axis is None else int(axis))
479-
.astype(utils.numpy_dtype(output_type))))
485+
.astype(utils.numpy_dtype(output_type or _default_index_type()))))
480486

481487
asin = utils.copy_docstring(
482488
'tf.math.asin',
@@ -542,8 +548,9 @@ def _unsorted_segment_sum(data, segment_ids, num_segments, name=None):
542548

543549
count_nonzero = utils.copy_docstring(
544550
'tf.math.count_nonzero',
545-
lambda input, axis=None, keepdims=None, dtype=np.int64, name=None: ( # pylint: disable=g-long-lambda
546-
utils.numpy_dtype(dtype)(np.count_nonzero(input, axis))))
551+
lambda input, axis=None, keepdims=None, dtype=None, name=None: ( # pylint: disable=g-long-lambda
552+
utils.numpy_dtype(dtype or _default_index_type())(
553+
np.count_nonzero(input, axis))))
547554

548555
cumprod = utils.copy_docstring(
549556
'tf.math.cumprod',

0 commit comments

Comments
 (0)