Skip to content

Commit 6efcda9

Browse files
vanderplastensorflower-gardener
authored andcommitted
Migrate TFP to support JAX typed PRNG keys
The context is described more fully in [JEP 9263](jax-ml/jax#17297). If you have comments on the JEP, we'd love to hear them! PiperOrigin-RevId: 566664782
1 parent a204ec8 commit 6efcda9

File tree

5 files changed

+34
-20
lines changed

5 files changed

+34
-20
lines changed

discussion/adaptive_malt/adaptive_malt.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def adaptive_mcmc_step(
350350
target_log_prob_fn: fun_mc.PotentialFn,
351351
num_mala_steps: int,
352352
num_adaptation_steps: int,
353-
seed: jax.random.KeyArray,
353+
seed: jax.Array,
354354
method: str = 'hmc',
355355
damping: Optional[jnp.ndarray] = None,
356356
scalar_step_size: Optional[jnp.ndarray] = None,
@@ -778,7 +778,7 @@ def adaptive_nuts_step(
778778
target_log_prob_fn: fun_mc.PotentialFn,
779779
num_mala_steps: int,
780780
num_adaptation_steps: int,
781-
seed: jax.random.KeyArray,
781+
seed: jax.Array,
782782
scalar_step_size: Optional[jnp.ndarray] = None,
783783
vector_step_size: Optional[jnp.ndarray] = None,
784784
rvar_factor: int = 8,
@@ -1040,7 +1040,7 @@ class MeadsExtra(NamedTuple):
10401040

10411041

10421042
def meads_init(state: jnp.ndarray, target_log_prob_fn: fun_mc.PotentialFn,
1043-
num_folds: int, seed: jax.random.KeyArray):
1043+
num_folds: int, seed: jax.Array):
10441044
"""Initializes MEADS."""
10451045
num_dimensions = state.shape[-1]
10461046
num_chains = state.shape[0]
@@ -1062,7 +1062,7 @@ def meads_init(state: jnp.ndarray, target_log_prob_fn: fun_mc.PotentialFn,
10621062

10631063
def meads_step(meads_state: MeadsState,
10641064
target_log_prob_fn: fun_mc.PotentialFn,
1065-
seed: jax.random.KeyArray,
1065+
seed: jax.Array,
10661066
vector_step_size: Optional[jnp.ndarray] = None,
10671067
damping: Optional[jnp.ndarray] = None,
10681068
step_size_multiplier: float = 0.5,
@@ -1221,7 +1221,7 @@ def run_adaptive_mcmc_on_target(
12211221
init_step_size: jnp.ndarray,
12221222
num_adaptation_steps: int,
12231223
num_results: int,
1224-
seed: jax.random.KeyArray,
1224+
seed: jax.Array,
12251225
num_mala_steps: int = 100,
12261226
rvar_smoothing: int = 0,
12271227
trajectory_opt_kwargs: Mapping[str, Any] = immutabledict.immutabledict({
@@ -1358,7 +1358,7 @@ def run_adaptive_nuts_on_target(
13581358
init_step_size: jnp.ndarray,
13591359
num_adaptation_steps: int,
13601360
num_results: int,
1361-
seed: jax.random.KeyArray,
1361+
seed: jax.Array,
13621362
num_mala_steps: int = 100,
13631363
rvar_smoothing: int = 0,
13641364
num_chains: Optional[int] = None,
@@ -1478,7 +1478,7 @@ def run_meads_on_target(
14781478
num_adaptation_steps: int,
14791479
num_results: int,
14801480
thinning: int,
1481-
seed: jax.random.KeyArray,
1481+
seed: jax.Array,
14821482
num_folds: int,
14831483
num_chains: Optional[int] = None,
14841484
init_x: Optional[jnp.ndarray] = None,
@@ -1596,7 +1596,7 @@ def run_fixed_mcmc_on_target(
15961596
target: gym.targets.Model,
15971597
init_x: jnp.ndarray,
15981598
method: str,
1599-
seed: jax.random.KeyArray,
1599+
seed: jax.Array,
16001600
num_warmup_steps: int,
16011601
num_results: int,
16021602
scalar_step_size: jnp.ndarray,
@@ -1706,7 +1706,7 @@ def run_vi_on_target(
17061706
init_x: jnp.ndarray,
17071707
num_steps: int,
17081708
learning_rate: float,
1709-
seed: jax.random.KeyArray,
1709+
seed: jax.Array,
17101710
):
17111711
"""Run VI on a target.
17121712

spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ def make_tensor_seed(seed):
9797
"""Converts a seed to a `Tensor` seed."""
9898
if seed is None:
9999
raise ValueError('seed must not be None when using JAX')
100-
if isinstance(seed, jax.random.PRNGKeyArray):
100+
if hasattr(seed, 'dtype') and jax.dtypes.issubdtype(
101+
seed.dtype, jax.dtypes.prng_key
102+
):
101103
return seed
102104
return jnp.asarray(seed, jnp.uint32)
103105

tensorflow_probability/python/internal/backend/numpy/ops.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,14 @@ def _default_convert_to_tensor(value, dtype=None):
218218
"""Default tensor conversion function for array, bool, int, float, and complex."""
219219
if JAX_MODE:
220220
# TODO(b/223267515): We shouldn't need to specialize here.
221-
if 'PRNGKeyArray' in str(type(value)):
221+
if hasattr(value, 'dtype') and jax.dtypes.issubdtype(
222+
value.dtype, jax.dtypes.prng_key
223+
):
222224
return value
223225
if isinstance(value, (list, tuple)) and value:
224-
if 'PRNGKeyArray' in str(type(value[0])):
226+
if hasattr(value[0], 'dtype') and jax.dtypes.issubdtype(
227+
value[0].dtype, jax.dtypes.prng_key
228+
):
225229
return np.stack(value, axis=0)
226230

227231
inferred_dtype = _infer_dtype(value, np.float32)

tensorflow_probability/python/internal/loop_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def _convert_variables_to_tensors(values):
5252

5353
def tensor_array_from_element(elem, size=None, **kwargs):
5454
"""Construct a tf.TensorArray of elements with the dtype + shape of `elem`."""
55-
if JAX_MODE and isinstance(elem, jax.random.PRNGKeyArray):
56-
# If `trace_elt` is a `PRNGKeyArray`, then then it is not possible to create
55+
if JAX_MODE and jax.dtypes.issubdtype(elem.dtype, jax.dtypes.prng_key):
56+
# If `trace_elt` is a typed prng key, then then it is not possible to create
5757
# a matching (i.e., with the same custom PRNG) instance/array inside
5858
# `TensorArray.__init__` given just a `dtype`, `size`, and `shape`.
5959
#

tensorflow_probability/python/internal/test_util.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,12 @@ def evaluate(self, x):
163163
def _evaluate(x):
164164
if x is None:
165165
return x
166-
# TODO(b/223267515): Improve handling of JAX PRNGKeyArray objects.
167-
if JAX_MODE and isinstance(x, jax.random.PRNGKeyArray):
166+
# TODO(b/223267515): Improve handling of JAX typed PRNG keys.
167+
if (
168+
JAX_MODE
169+
and hasattr(x, 'dtype')
170+
and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)
171+
):
168172
return x
169173
return np.array(x)
170174
return tf.nest.map_structure(_evaluate, x, expand_composites=True)
@@ -177,11 +181,15 @@ def _GetNdArray(self, a):
177181
def _evaluateTensors(self, a, b):
178182
if JAX_MODE:
179183
import jax # pylint: disable=g-import-not-at-top
180-
# HACK: In assertions (like self.assertAllClose), convert PRNGKeyArrays
181-
# to "normal" arrays so they can be compared with our existing machinery.
182-
if isinstance(a, jax.random.PRNGKeyArray):
184+
# HACK: In assertions (like self.assertAllClose), convert typed PRNG keys
185+
# to raw arrays so they can be compared with our existing machinery.
186+
if hasattr(a, 'dtype') and jax.dtypes.issubdtype(
187+
a.dtype, jax.dtypes.prng_key
188+
):
183189
a = jax.random.key_data(a)
184-
if isinstance(b, jax.random.PRNGKeyArray):
190+
if hasattr(b, 'dtype') and jax.dtypes.issubdtype(
191+
b.dtype, jax.dtypes.prng_key
192+
):
185193
b = jax.random.key_data(b)
186194
if tf.is_tensor(a) and tf.is_tensor(b):
187195
(a, b) = self.evaluate([a, b])

0 commit comments

Comments
 (0)