@@ -185,6 +185,12 @@ def _astuple(x):
185
185
return x
186
186
187
187
188
+ def _default_index_type ():
189
+ if JAX_MODE and not jax .config .read ('jax_enable_x64' ):
190
+ return np .int32
191
+ return np .int64
192
+
193
+
188
194
def _bincount (arr , weights = None , minlength = None , maxlength = None , # pylint: disable=unused-argument
189
195
dtype = np .int32 , name = None ): # pylint: disable=unused-argument
190
196
"""Counts number of occurences of each value in `arr`."""
@@ -467,16 +473,16 @@ def _unsorted_segment_sum(data, segment_ids, num_segments, name=None):
467
473
468
474
argmax = utils .copy_docstring (
469
475
'tf.math.argmax' ,
470
- lambda input , axis = None , output_type = np . int64 , name = None : ( # pylint: disable=g-long-lambda
476
+ lambda input , axis = None , output_type = None , name = None : ( # pylint: disable=g-long-lambda
471
477
np .argmax (input , axis = 0 if axis is None else int (axis ))
472
- .astype (utils .numpy_dtype (output_type ))))
478
+ .astype (utils .numpy_dtype (output_type or _default_index_type () ))))
473
479
474
480
argmin = utils .copy_docstring (
475
481
'tf.math.argmin' ,
476
- lambda input , axis = None , output_type = np . int64 , name = None : ( # pylint: disable=g-long-lambda
482
+ lambda input , axis = None , output_type = None , name = None : ( # pylint: disable=g-long-lambda
477
483
np .argmin (_convert_to_tensor (
478
484
input ), axis = 0 if axis is None else int (axis ))
479
- .astype (utils .numpy_dtype (output_type ))))
485
+ .astype (utils .numpy_dtype (output_type or _default_index_type () ))))
480
486
481
487
asin = utils .copy_docstring (
482
488
'tf.math.asin' ,
@@ -542,8 +548,9 @@ def _unsorted_segment_sum(data, segment_ids, num_segments, name=None):
542
548
543
549
count_nonzero = utils .copy_docstring (
544
550
'tf.math.count_nonzero' ,
545
- lambda input , axis = None , keepdims = None , dtype = np .int64 , name = None : ( # pylint: disable=g-long-lambda
546
- utils .numpy_dtype (dtype )(np .count_nonzero (input , axis ))))
551
+ lambda input , axis = None , keepdims = None , dtype = None , name = None : ( # pylint: disable=g-long-lambda
552
+ utils .numpy_dtype (dtype or _default_index_type ())(
553
+ np .count_nonzero (input , axis ))))
547
554
548
555
cumprod = utils .copy_docstring (
549
556
'tf.math.cumprod' ,
0 commit comments