Skip to content

Commit 69d88a0

Browse files
sharadmvtensorflower-gardener
authored andcommitted
[Oryx] Add plate transformation for using named axes in Oryx that parallels
the design of the `Sharded` distribution in TFP. Also simplifies the `HigherOrderPrimitive`. PiperOrigin-RevId: 380879501
1 parent cbf2a7f commit 69d88a0

File tree

9 files changed

+409
-135
lines changed

9 files changed

+409
-135
lines changed

spinoffs/oryx/oryx/core/ppl/BUILD

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ py_library(
2525
srcs = ["transformations.py"],
2626
srcs_version = "PY3",
2727
deps = [
28+
":plate_util",
2829
# jax dep,
2930
"//oryx/core:primitive",
3031
"//oryx/core/interpreters:harvest",
@@ -37,8 +38,19 @@ py_library(
3738
name = "ppl",
3839
srcs = ["__init__.py"],
3940
srcs_version = "PY3",
41+
deps = [":transformations"],
42+
)
43+
44+
# pytype_strict
45+
py_library(
46+
name = "plate_util",
47+
srcs = ["plate_util.py"],
48+
srcs_version = "PY3",
4049
deps = [
41-
":transformations",
50+
# jax dep,
51+
"//oryx/core:primitive",
52+
"//oryx/core/interpreters:log_prob",
53+
"//oryx/core/interpreters:propagate",
4254
],
4355
)
4456

@@ -52,7 +64,9 @@ py_test(
5264
":transformations",
5365
# absl/testing:absltest dep,
5466
# jax dep,
67+
# numpy dep,
5568
"//oryx/core/interpreters:log_prob",
5669
"//oryx/internal:test_util",
70+
# tensorflow_probability/substrates:jax dep,
5771
],
5872
)

spinoffs/oryx/oryx/core/ppl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from oryx.core.ppl.transformations import log_prob
2424
from oryx.core.ppl.transformations import LogProbFunction
2525
from oryx.core.ppl.transformations import nest
26+
from oryx.core.ppl.transformations import plate
2627
from oryx.core.ppl.transformations import Program
2728
from oryx.core.ppl.transformations import random_variable
2829
from oryx.core.ppl.transformations import RANDOM_VARIABLE
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2021 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
"""Contains utilities for the `plate` transformation.
16+
17+
A plate is a term in graphical models that is used to designate independent
18+
random variables. In Oryx, `plate` is a transformation that converts a program
19+
into one that produces independent samples. Ordinarily, this can be done with
20+
`jax.vmap`, where we could split several random keys and map a program over
21+
them. Unlike `jax.vmap`, `plate` operates using named axes. A `plate`-ed
22+
program will specialize the random seed to the particular index of the axis
23+
being mapped over. Taking the `log_prob` of a `plate` program will reduce over
24+
the named axis. In design, `plate` resembles the `Sharded` meta-distribution
25+
from TensorFlow Probability.
26+
27+
In implementation, `plate` is an Oryx `HigherOrderPrimitive` (i.e. a JAX
28+
`CallPrimitive` with a `log_prob` rule that reduces over a named axis at the
29+
end.
30+
"""
31+
import functools
32+
33+
from jax import lax
34+
from jax import random
35+
36+
from oryx.core import primitive
37+
from oryx.core.interpreters import log_prob
38+
from oryx.core.interpreters import propagate
39+
40+
41+
__all__ = [
42+
'make_plate',
43+
]
44+
45+
46+
plate_p = primitive.HigherOrderPrimitive('plate')
47+
48+
49+
def plate_log_prob_rule(incells, outcells, *, plate, **params):
50+
incells, outcells, lp = propagate.call_rule(
51+
plate_p, incells, outcells, plate=plate, **params)
52+
return incells, outcells, lax.psum(lp, plate)
53+
54+
55+
log_prob.log_prob_rules[plate_p] = plate_log_prob_rule
56+
log_prob.log_prob_registry.add(plate_p)
57+
58+
59+
def make_plate(f, *, name):
60+
"""Wraps a probabilistic program in a plate with a named axis."""
61+
62+
@functools.wraps(f)
63+
def plate_fun(key, *args, **kwargs):
64+
key = random.fold_in(key, lax.axis_index(name))
65+
return f(key, *args, **kwargs)
66+
67+
def wrapped(key, *args, **kwargs):
68+
return primitive.call_bind(
69+
plate_p, plate=name)(plate_fun)(key, *args, **kwargs)
70+
71+
return wrapped

spinoffs/oryx/oryx/core/ppl/transformations.py

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def f(key):
219219
from oryx.core import primitive
220220
from oryx.core.interpreters import harvest
221221
from oryx.core.interpreters import log_prob as lp
222+
from oryx.core.ppl import plate_util
222223

223224
__all__ = [
224225
'block',
@@ -245,7 +246,10 @@ def f(key):
245246

246247

247248
@functools.singledispatch
248-
def random_variable(obj, *, name: Optional[str] = None) -> Program:
249+
def random_variable(obj,
250+
*,
251+
name: Optional[str] = None,
252+
plate: Optional[str] = None) -> Program: # pylint: disable=redefined-outer-name
249253
"""A single-dispatch function used to tag values and the outputs of programs.
250254
251255
`random_variable` is a single-dispatch function that enables registering
@@ -255,16 +259,67 @@ def random_variable(obj, *, name: Optional[str] = None) -> Program:
255259
Args:
256260
obj: A JAX type to be tagged.
257261
name (str): A string name to tag input value, cannot be `None`.
262+
plate (str): A string named axis for this random variable's plate.
258263
259264
Returns:
260265
The input value.
261266
"""
262267
if name is None:
263268
raise ValueError(f'Cannot call `random_variable` on {type(obj)} '
264269
'without passing in a name.')
270+
if plate is not None:
271+
raise ValueError(f'Cannot call `random_variable` on {type(obj)} '
272+
'with a plate.')
265273
return harvest.sow(obj, tag=RANDOM_VARIABLE, name=name, mode='strict')
266274

267275

276+
def plate(f: Optional[Program] = None, name: Optional[str] = None):
277+
"""Transforms a program into one that draws samples on a named axis.
278+
279+
In graphical model parlance, a plate designates independent random variables.
280+
The `plate` transformation follows this idea, where a `plate`-ed program
281+
draws independent samples. Unlike `jax.vmap`-ing a program, which also
282+
produces independent samples with positional batch dimensions, `plate`
283+
produces samples with implicit named axes. Named axis support is useful for
284+
other JAX transformations like `pmap` and `xmap`.
285+
286+
Specifically, a `plate`-ed program creates a different key for each axis
287+
of the named axis. `log_prob` reduces over the named axis to produce a single
288+
value.
289+
290+
Example usage:
291+
```python
292+
@ppl.plate(name='foo')
293+
def model(key):
294+
return random_variable(random.normal)(key)
295+
# We can't call model directly because there are implicit named axes present
296+
try:
297+
model(random.PRNGKey(0))
298+
except NameError:
299+
print('No named axis present!')
300+
# If we vmap with a named axis, we produce independent samples.
301+
vmap(model, axis_name='foo')(random.split(random.PRNGKey(0), 3)) #
302+
```
303+
304+
Args:
305+
f: a `Program` to transform. If `f` is `None`, `plate` returns a decorator.
306+
name: a `str` name for the plate which can used as a name axis in JAX
307+
functions and transformations.
308+
309+
Returns:
310+
A decorator if `f` is `None` or a transformed program if `f` is provided.
311+
The transformed program behaves produces independent across a named
312+
axis with name `name`.
313+
"""
314+
315+
def transform(f: Program) -> Program:
316+
return plate_util.make_plate(f, name=name)
317+
318+
if f is not None:
319+
return transform(f)
320+
return transform
321+
322+
268323
# Alias for random_variable
269324
rv = random_variable
270325

@@ -273,21 +328,26 @@ def random_variable(obj, *, name: Optional[str] = None) -> Program:
273328
@random_variable.register(functools.partial)
274329
def function_random_variable(f: Program,
275330
*,
276-
name: Optional[str] = None) -> Program:
331+
name: Optional[str] = None,
332+
plate: Optional[str] = None) -> Program: # pylint: disable=redefined-outer-name
277333
"""Registers functions with the `random_variable` single dispatch function.
278334
279335
Args:
280336
f: A probabilistic program.
281337
name (str): A string name that is used to when tagging the output of `f`.
338+
plate (str): A string named axis for this random variable's plate.
282339
283340
Returns:
284341
A probabilistic program whose output is tagged with `name`.
285342
"""
286343

287344
def wrapped(*args, **kwargs):
345+
fun = f
346+
if plate is not None:
347+
fun = plate_util.make_plate(fun, name=plate)
288348
if name is not None:
289-
return random_variable(nest(f, scope=name)(*args, **kwargs), name=name)
290-
return f(*args, **kwargs)
349+
return random_variable(nest(fun, scope=name)(*args, **kwargs), name=name)
350+
return fun(*args, **kwargs)
291351

292352
return wrapped
293353

0 commit comments

Comments
 (0)