Skip to content

Commit 8ed93b4

Browse files
sharadmvtensorflower-gardener
authored andcommitted
[Oryx] Enable vmap rule for tfd.Distributions
Also bumps minor version PiperOrigin-RevId: 381133393
1 parent adb87b0 commit 8ed93b4

File tree

4 files changed

+63
-4
lines changed

4 files changed

+63
-4
lines changed

spinoffs/oryx/oryx/distributions/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ py_library(
3838
deps = [
3939
# jax dep,
4040
"//oryx/core:primitive",
41+
"//oryx/core:trace_util",
4142
"//oryx/core/interpreters:harvest",
4243
"//oryx/core/interpreters:log_prob",
4344
"//oryx/core/interpreters:unzip",

spinoffs/oryx/oryx/distributions/distribution_extensions.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,17 @@
1414
# ============================================================================
1515
# Lint as: python3
1616
"""Wraps TFP distributions for use with Jax."""
17+
import itertools as it
18+
1719
from typing import Optional
1820

21+
import jax
1922
from jax import tree_util
2023
from jax import util as jax_util
24+
from jax.interpreters import batching
2125
from oryx.core import ppl
2226
from oryx.core import primitive
27+
from oryx.core import trace_util
2328
from oryx.core.interpreters import harvest
2429
from oryx.core.interpreters import inverse
2530
from oryx.core.interpreters import log_prob
@@ -36,7 +41,7 @@
3641

3742

3843
def random_variable_log_prob_rule(flat_incells, flat_outcells, *, num_consts,
39-
in_tree, out_tree, **_):
44+
in_tree, out_tree, batch_ndims, **_):
4045
"""Registers Oryx distributions with the log_prob transformation."""
4146
_, incells = jax_util.split_list(flat_incells, [num_consts])
4247
val_incells = incells[1:]
@@ -48,13 +53,37 @@ def random_variable_log_prob_rule(flat_incells, flat_outcells, *, num_consts,
4853
flat_outvals = [cell.val for cell in flat_outcells]
4954
_, dist = tree_util.tree_unflatten(in_tree, seed_flat_invals)
5055
outval = tree_util.tree_unflatten(out_tree, flat_outvals)
51-
return flat_incells, flat_outcells, dist.log_prob(outval)
56+
return flat_incells, flat_outcells, dist.log_prob(outval).sum(
57+
axis=list(range(batch_ndims)))
5258

5359
log_prob.log_prob_rules[random_variable_p] = random_variable_log_prob_rule
5460

5561
log_prob.log_prob_registry.add(random_variable_p)
5662

5763

64+
def random_variable_batching_rule(args, dims, *, num_consts, batch_ndims,
65+
jaxpr, **params):
66+
"""Batching (vmap) rule for the `random_variable` primitive."""
67+
old_consts = args[:num_consts]
68+
args, dims = args[num_consts:], dims[num_consts:]
69+
def _run(*args):
70+
return random_variable_p.impl(*it.chain(old_consts, args),
71+
num_consts=len(old_consts),
72+
jaxpr=jaxpr,
73+
batch_ndims=batch_ndims,
74+
**params)
75+
run = jax.vmap(_run, in_axes=dims, out_axes=0)
76+
closed_jaxpr, _ = trace_util.stage(run, dynamic=True)(*args)
77+
new_jaxpr, new_consts = closed_jaxpr.jaxpr, closed_jaxpr.literals
78+
result = random_variable_p.bind(*it.chain(new_consts, args),
79+
num_consts=len(new_consts),
80+
jaxpr=new_jaxpr,
81+
batch_ndims=batch_ndims + 1,
82+
**params)
83+
return result, (0,) * len(result)
84+
batching.primitive_batchers[random_variable_p] = random_variable_batching_rule
85+
86+
5887
def _sample_distribution(key, dist):
5988
return dist.sample(seed=key)
6089

@@ -66,10 +95,15 @@ def distribution_random_variable(dist: tfd.Distribution, *,
6695
"""Converts a distribution into a sampling function."""
6796
if plate is not None:
6897
dist = tfed.Sharded(dist, plate)
98+
if dist.batch_shape != []: # pylint: disable=g-explicit-bool-comparison
99+
raise ValueError(
100+
f'Cannot use a distribution with `batch_shape`: {dist.batch_shape}. '
101+
'Instead, use `jax.vmap` or `ppl.plate` to draw independent samples.')
69102
def wrapped(key):
70103
def sample(key):
71104
result = primitive.initial_style_bind(
72105
random_variable_p,
106+
batch_ndims=0,
73107
distribution_name=dist.__class__.__name__)(_sample_distribution)(
74108
key, dist)
75109
return result

spinoffs/oryx/oryx/distributions/distribution_extensions_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,30 @@ def test_plate_reduces_over_named_axes(self):
252252
np.testing.assert_allclose(
253253
tfd.Normal(0., 1.).log_prob(jnp.arange(3.)).sum(), out)
254254

255+
def test_vmapping_distribution_reduces_to_scalar_log_prob(self):
256+
257+
def model(key):
258+
return jax.vmap(ppl.rv(tfd.Normal(0., 1.)))(random.split(key))
259+
260+
out = ppl.log_prob(model)(jnp.arange(2.))
261+
np.testing.assert_allclose(
262+
tfd.Normal(0., 1.).log_prob(jnp.arange(2.)).sum(), out)
263+
264+
def test_can_map_over_batches_with_vmap_and_reduce_to_scalar_log_prob(self):
265+
266+
def f(key, x):
267+
return ppl.rv(tfd.Normal(x, 1.))(key)
268+
269+
def model(key, xs):
270+
return jax.vmap(f)(random.split(key), xs)
271+
272+
out = ppl.log_prob(model)(jnp.arange(2.), 2 * jnp.arange(2.))
273+
np.testing.assert_allclose(
274+
tfd.Normal(jnp.arange(2.), 1.).log_prob(2 * jnp.arange(2.)).sum(), out)
275+
276+
def test_cannot_use_distribution_with_nontrivial_batch_shape(self):
277+
with self.assertRaises(ValueError):
278+
ppl.rv(tfd.Normal(jnp.ones(2), 1.))(random.PRNGKey(0))
255279

256280
if __name__ == '__main__':
257281
absltest.main()

spinoffs/oryx/oryx/version.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
# We follow Semantic Versioning (https://semver.org/)
1818
_MAJOR_VERSION = '0'
19-
_MINOR_VERSION = '1'
20-
_PATCH_VERSION = '4'
19+
_MINOR_VERSION = '2'
20+
_PATCH_VERSION = '0'
2121

2222
# When building releases, we can update this value on the release branch to
2323
# reflect the current release candidate ('rc0', 'rc1') or, finally, the official

0 commit comments

Comments
 (0)