Skip to content

Commit 2750050

Browse files
hawkinsptensorflower-gardener
authored andcommitted
[JAX] Explicitly cast large integer constants to uint32.
We intend to add `jax.jit` decorators around a number of functions in the JAX standard library, including operators such as `&`. A consequence of this is that JAX will attempt to cast Python integers (n.b. not NumPy scalars) to signed int32 or int64 types depending on the JAX x64 mode. The constant 2**32 - 1 is out of range of int32 (absent -x64 mode) and will produce an error under the new semantics. Instead, use a Numpy uint32 scalar. PiperOrigin-RevId: 388950774
1 parent f7eaefe commit 2750050

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tensorflow_probability/python/internal/samplers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ def fold_in(seed, salt):
145145
if JAX_MODE:
146146
from jax import random as jaxrand # pylint: disable=g-import-not-at-top
147147
import jax.numpy as jnp # pylint: disable=g-import-not-at-top
148-
return jaxrand.fold_in(seed,
149-
jnp.asarray(salt & (2**32 - 1), dtype=SEED_DTYPE))
148+
return jaxrand.fold_in(
149+
seed, jnp.asarray(salt & np.uint32(2**32 - 1), dtype=SEED_DTYPE))
150150
if isinstance(salt, (six.integer_types)):
151151
seed = tf.bitwise.bitwise_xor(
152152
seed, np.uint64([salt & (2**64 - 1)]).view(np.int32))

0 commit comments

Comments
 (0)