You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[JAX] Explicitly cast large integer constants to uint32.
We intend to add `jax.jit` decorators around a number of functions in the JAX standard library, including operators such as `&`.
A consequence of this is that JAX will attempt to cast Python integers (n.b. not NumPy scalars) to signed int32 or int64 types depending on the JAX x64 mode. The constant 2**32 - 1 is out of range of int32 (absent -x64 mode) and will produce an error under the new semantics. Instead, use a Numpy uint32 scalar.
PiperOrigin-RevId: 388950774
0 commit comments