Skip to content

Commit cbf2a7f

Browse files
sharadmvtensorflower-gardener
authored andcommitted
[Oryx] Update propagate transformation to support state at each equation
PiperOrigin-RevId: 380879404
1 parent cd7b373 commit cbf2a7f

File tree

8 files changed

+254
-231
lines changed

8 files changed

+254
-231
lines changed

spinoffs/oryx/oryx/core/interpreters/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ py_library(
5353
srcs = ["propagate.py"],
5454
srcs_version = "PY3",
5555
deps = [
56+
":harvest",
5657
# dataclasses dep,
5758
# jax dep,
5859
"//oryx/core:pytree",

spinoffs/oryx/oryx/core/interpreters/inverse/core.py

Lines changed: 16 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,13 @@
1818
import jax
1919
from jax import abstract_arrays
2020
from jax import core as jax_core
21-
from jax import linear_util as lu
2221
from jax import tree_util
2322
from jax import util as jax_util
24-
from jax.interpreters import partial_eval as pe
2523
from jax.interpreters import pxla
26-
from jax.interpreters import xla
2724
import jax.numpy as np
2825

2926
from oryx.core import primitive
3027
from oryx.core import trace_util
31-
from oryx.core.interpreters import harvest
3228
from oryx.core.interpreters import propagate
3329
from oryx.core.interpreters.inverse import slice as slc
3430

@@ -177,8 +173,8 @@ def wrapped(*args, **kwargs):
177173
for arg in flat_forward_args]
178174
flat_incells = [InverseAndILDJ.unknown(aval) for aval in flat_forward_avals]
179175
flat_outcells = safe_map(InverseAndILDJ.new, flat_args)
180-
env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr.jaxpr,
181-
flat_constcells, flat_incells, flat_outcells) # pytype: disable=wrong-arg-types
176+
env, _ = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr.jaxpr,
177+
flat_constcells, flat_incells, flat_outcells) # pytype: disable=wrong-arg-types
182178
flat_incells = [env.read(invar) for invar in jaxpr.jaxpr.invars]
183179
if any(not flat_incell.top() for flat_incell in flat_incells):
184180
raise ValueError('Cannot invert function.')
@@ -246,6 +242,9 @@ def __getitem__(self, prim):
246242
def __setitem__(self, prim, val):
247243
self.rules[prim] = val
248244

245+
def __contains__(self, prim):
246+
return prim in self.rules
247+
249248

250249
def register_elementwise(prim):
251250
"""Registers an elementwise primitive with ILDJ."""
@@ -296,46 +295,19 @@ def ildj_rule(incells, outcells, **params):
296295
ildj_registry = InverseDict()
297296

298297

299-
@lu.transformation_with_aux
300-
def flat_propagate(tree, *flat_invals):
301-
invals, outvals = tree_util.tree_unflatten(tree, flat_invals)
302-
subenv = yield ((invals, outvals), {})
303-
subenv_vals, subenv_tree = tree_util.tree_flatten(subenv)
304-
yield subenv_vals, subenv_tree
305-
306-
307-
def call_ildj(prim, incells, outcells, **params):
308-
"""InverseAndILDJ rule for call primitives."""
309-
f, incells = incells[0], incells[1:]
310-
flat_vals, in_tree = tree_util.tree_flatten((incells, outcells))
311-
new_params = dict(params)
312-
if 'donated_invars' in params:
313-
new_params['donated_invars'] = (False,) * len(flat_vals)
314-
f, aux = flat_propagate(f, in_tree)
315-
subenv_vals = prim.bind(f, *flat_vals, **new_params)
316-
subenv_tree = aux()
317-
subenv = tree_util.tree_unflatten(subenv_tree, subenv_vals)
318-
new_incells = [subenv.read(var) for var in subenv.jaxpr.invars]
319-
new_outcells = [subenv.read(var) for var in subenv.jaxpr.outvars]
320-
return new_incells, new_outcells, subenv
321-
ildj_registry[xla.xla_call_p] = jax_util.partial(call_ildj, xla.xla_call_p)
322-
ildj_registry[jax_core.call_p] = jax_util.partial(call_ildj, jax_core.call_p)
323-
ildj_registry[pe.remat_call_p] = jax_util.partial(call_ildj, pe.remat_call_p)
324-
ildj_registry[harvest.nest_p] = jax_util.partial(call_ildj, harvest.nest_p)
325-
326-
327298
def hop_inverse_rule(prim):
328-
ildj_registry[prim] = jax_util.partial(call_ildj, prim)
299+
ildj_registry[prim] = jax_util.partial(propagate.call_rule, prim)
329300
primitive.register_hop_transformation_rule('inverse', hop_inverse_rule)
330301

331302

332303
def initial_ildj(incells, outcells, *, jaxpr, num_consts, **_):
333304
const_cells, incells = jax_util.split_list(incells, [num_consts])
334-
env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr, const_cells,
335-
incells, outcells) # pytype: disable=wrong-arg-types
305+
env, state = propagate.propagate(
306+
InverseAndILDJ, ildj_registry, jaxpr, const_cells,
307+
incells, outcells) # pytype: disable=wrong-arg-types
336308
new_incells = [env.read(invar) for invar in jaxpr.invars]
337309
new_outcells = [env.read(outvar) for outvar in jaxpr.outvars]
338-
return const_cells + new_incells, new_outcells, None
310+
return const_cells + new_incells, new_outcells, state
339311

340312

341313
def initial_inverse_rule(prim):
@@ -371,7 +343,7 @@ def remove_slice(cell):
371343
mapped_incells = safe_map(remove_slice, incells)
372344
mapped_outcells = safe_map(remove_slice, outcells)
373345
flat_vals, in_tree = tree_util.tree_flatten((mapped_incells, mapped_outcells))
374-
f, aux = flat_propagate(f, in_tree)
346+
f, aux = propagate.flat_propagate(f, in_tree)
375347
# Assume all invars as mapped
376348
new_in_axes = (0,) * len(flat_vals)
377349
new_params = dict(params, in_axes=new_in_axes)
@@ -383,14 +355,13 @@ def remove_slice(cell):
383355
lambda: (0,) * aux().num_leaves,
384356
closure=('ildj', params['out_axes']))
385357
del new_params['out_axes']
386-
subenv_vals = prim.bind(f, *flat_vals, **new_params)
387-
subenv_tree = aux()
388-
subenv = tree_util.tree_unflatten(subenv_tree, subenv_vals)
389-
new_incells = [subenv.read(var) for var in subenv.jaxpr.invars]
390-
new_outcells = [subenv.read(var) for var in subenv.jaxpr.outvars]
358+
flat_out = prim.bind(f, *flat_vals, **new_params)
359+
out_tree = aux()
360+
new_incells, new_outcells, state = tree_util.tree_unflatten(
361+
out_tree, flat_out)
391362
new_incells = [add_slice(v, old_v)
392363
for old_v, v in safe_zip(incells, new_incells)]
393364
new_outcells = [add_slice(v, old_v)
394365
for old_v, v in safe_zip(outcells, new_outcells)]
395-
return new_incells, new_outcells, subenv
366+
return new_incells, new_outcells, state
396367
ildj_registry[pxla.xla_pmap_p] = jax_util.partial(map_ildj, pxla.xla_pmap_p)

spinoffs/oryx/oryx/core/interpreters/log_prob.py

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from jax import core as jax_core
1818
from jax import random
1919
from jax import tree_util
20-
import jax.numpy as np
2120

2221
from oryx.core import trace_util
2322
from oryx.core.interpreters import inverse
@@ -47,16 +46,17 @@ def __missing__(self, prim):
4746
self[prim] = rule = make_default_rule(prim)
4847
return rule
4948

50-
log_prob_rules = LogProbRules()
5149

50+
log_prob_rules = LogProbRules()
5251

5352
# The log_prob_registry is used to compute log_prob values from samples after
5453
# propagation is done.
55-
log_prob_registry = {}
54+
log_prob_registry = set()
5655

5756

5857
def log_prob(f):
5958
"""LogProb function transformation."""
59+
6060
def wrapped(sample, *args, **kwargs):
6161
"""Function wrapper that takes in log_prob arguments."""
6262
# Trace the function using a random seed
@@ -69,50 +69,55 @@ def wrapped(sample, *args, **kwargs):
6969
InverseAndILDJ.unknown(trace_util.get_shaped_aval(dummy_seed))
7070
] + [InverseAndILDJ.new(val) for val in flat_inargs]
7171
flat_outcells = [InverseAndILDJ.new(a) for a in flat_outargs]
72-
# Re-use the InverseAndILDJ propagation but silently fail instead of
73-
# erroring when we hit a primitive we can't invert.
74-
env = propagate.propagate(InverseAndILDJ, log_prob_rules, jaxpr.jaxpr,
75-
constcells, flat_incells, flat_outcells)
76-
# Traverse the resulting environment, looking for primitives that have
77-
# registered log_probs.
78-
final_log_prob = _accumulate_log_probs(env)
79-
return final_log_prob
80-
return wrapped
72+
return log_prob_jaxpr(jaxpr.jaxpr, constcells, flat_incells, flat_outcells)
8173

74+
return wrapped
8275

83-
def _accumulate_log_probs(env):
84-
"""Recursively traverses Jaxprs to accumulate log_prob values."""
85-
final_log_prob = 0.0
86-
eqns = safe_map(propagate.Equation.from_jaxpr_eqn, env.jaxpr.eqns)
87-
for eqn in eqns:
88-
if eqn.primitive in log_prob_registry:
89-
var, = eqn.outvars
90-
if var not in env:
91-
raise ValueError('Cannot compute log_prob of function.')
92-
incells = [env.read(v) for v in eqn.invars]
93-
outcells = [env.read(v) for v in eqn.outvars]
94-
outcell, = outcells
95-
if not outcell.top():
96-
raise ValueError('Cannot compute log_prob of function.')
97-
lp = log_prob_registry[eqn.primitive](
98-
[cell if not cell.top() else cell.val for cell in incells],
99-
outcell.val, **eqn.params
100-
)
101-
assert np.ndim(lp) == 0, 'log_prob must return a scalar.'
102-
# Accumulate ILDJ term
103-
final_log_prob += lp + np.sum(outcell.ildj)
104-
for subenv in env.subenvs.values():
105-
sub_lp = _accumulate_log_probs(subenv)
106-
final_log_prob += sub_lp
76+
failed_log_prob = object() # sentinel for being unable to compute a log_prob
77+
78+
79+
def log_prob_jaxpr(jaxpr, constcells, flat_incells, flat_outcells):
80+
"""Runs log_prob propagation on a Jaxpr."""
81+
82+
def reducer(env, eqn, curr_log_prob, new_log_prob):
83+
if curr_log_prob is failed_log_prob or new_log_prob is failed_log_prob:
84+
# If `curr_log_prob` is `None` that means we were unable to compute
85+
# a log_prob elsewhere, so the propagate failed.
86+
return failed_log_prob
87+
if eqn.primitive in log_prob_registry and new_log_prob is None:
88+
# We are unable to compute a log_prob for this primitive.
89+
return failed_log_prob
90+
if new_log_prob is not None:
91+
cells = [env.read(var) for var in eqn.outvars]
92+
ildjs = sum([cell.ildj.sum() for cell in cells if cell.top()])
93+
return curr_log_prob + new_log_prob + ildjs
94+
return curr_log_prob
95+
96+
# Re-use the InverseAndILDJ propagation but silently fail instead of
97+
# erroring when we hit a primitive we can't invert. We accumulate the log
98+
# probability values using the propagater state.
99+
_, final_log_prob = propagate.propagate(
100+
InverseAndILDJ,
101+
log_prob_rules,
102+
jaxpr,
103+
constcells,
104+
flat_incells,
105+
flat_outcells,
106+
reducer=reducer,
107+
initial_state=0.)
108+
if final_log_prob is failed_log_prob:
109+
raise ValueError('Cannot compute log_prob of function.')
107110
return final_log_prob
108111

109112

110113
def make_default_rule(prim):
111114
"""Creates rule for prim without a registered log_prob."""
115+
112116
def rule(incells, outcells, **params):
113117
"""Executes the inverse rule but fails if the inverse isn't implemented."""
114118
try:
115119
return ildj_registry[prim](incells, outcells, **params)
116120
except NotImplementedError:
117121
return incells, outcells, None
122+
118123
return rule

spinoffs/oryx/oryx/core/interpreters/log_prob_test.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from oryx.core import state
2929
from oryx.core.interpreters.log_prob import log_prob
3030
from oryx.core.interpreters.log_prob import log_prob_registry
31+
from oryx.core.interpreters.log_prob import log_prob_rules
3132
from oryx.internal import test_util
3233

3334
random_normal_p = jax_core.Primitive('random_normal')
@@ -49,10 +50,14 @@ def random_normal_abstract(_, name=None):
4950
random_normal_p.def_abstract_eval(random_normal_abstract)
5051

5152

52-
def random_normal_log_prob(_, outval, name=None):
53-
del name
54-
return bd.Normal(0., 1.).log_prob(outval)
55-
log_prob_registry[random_normal_p] = random_normal_log_prob
53+
def random_normal_log_prob_rule(incells, outcells, **_):
54+
outcell, = outcells
55+
if not outcell.top():
56+
return incells, outcells, None
57+
outval = outcell.val
58+
return incells, outcells, bd.Normal(0., 1.).log_prob(outval)
59+
log_prob_rules[random_normal_p] = random_normal_log_prob_rule
60+
log_prob_registry.add(random_normal_p)
5661

5762

5863
def call(f):

0 commit comments

Comments
 (0)