@@ -14,12 +14,17 @@
 # ============================================================================
 # Lint as: python3
 """Wraps TFP distributions for use with Jax."""
+import itertools as it
+
 from typing import Optional
 
+import jax
 from jax import tree_util
 from jax import util as jax_util
+from jax.interpreters import batching
 from oryx.core import ppl
 from oryx.core import primitive
+from oryx.core import trace_util
 from oryx.core.interpreters import harvest
 from oryx.core.interpreters import inverse
 from oryx.core.interpreters import log_prob
@@ -36,7 +41,7 @@
 
 
 def random_variable_log_prob_rule(flat_incells, flat_outcells, *, num_consts,
-                                  in_tree, out_tree, **_):
+                                  in_tree, out_tree, batch_ndims, **_):
   """Registers Oryx distributions with the log_prob transformation."""
   _, incells = jax_util.split_list(flat_incells, [num_consts])
   val_incells = incells[1:]
@@ -48,13 +53,37 @@ def random_variable_log_prob_rule(flat_incells, flat_outcells, *, num_consts,
   flat_outvals = [cell.val for cell in flat_outcells]
   _, dist = tree_util.tree_unflatten(in_tree, seed_flat_invals)
   outval = tree_util.tree_unflatten(out_tree, flat_outvals)
-  return flat_incells, flat_outcells, dist.log_prob(outval)
+  return flat_incells, flat_outcells, dist.log_prob(outval).sum(
+      axis=list(range(batch_ndims)))
 
 log_prob.log_prob_rules[random_variable_p] = random_variable_log_prob_rule
 
 log_prob.log_prob_registry.add(random_variable_p)
 
 
+def random_variable_batching_rule(args, dims, *, num_consts, batch_ndims,
+                                  jaxpr, **params):
+  """Batching (vmap) rule for the `random_variable` primitive."""
+  old_consts = args[:num_consts]
+  args, dims = args[num_consts:], dims[num_consts:]
+  def _run(*args):
+    return random_variable_p.impl(*it.chain(old_consts, args),
+                                  num_consts=len(old_consts),
+                                  jaxpr=jaxpr,
+                                  batch_ndims=batch_ndims,
+                                  **params)
+  run = jax.vmap(_run, in_axes=dims, out_axes=0)
+  closed_jaxpr, _ = trace_util.stage(run, dynamic=True)(*args)
+  new_jaxpr, new_consts = closed_jaxpr.jaxpr, closed_jaxpr.literals
+  result = random_variable_p.bind(*it.chain(new_consts, args),
+                                  num_consts=len(new_consts),
+                                  jaxpr=new_jaxpr,
+                                  batch_ndims=batch_ndims + 1,
+                                  **params)
+  return result, (0,) * len(result)
+batching.primitive_batchers[random_variable_p] = random_variable_batching_rule
+
+
 def _sample_distribution(key, dist):
   return dist.sample(seed=key)
 
@@ -66,10 +95,15 @@ def distribution_random_variable(dist: tfd.Distribution, *,
   """Converts a distribution into a sampling function."""
   if plate is not None:
     dist = tfed.Sharded(dist, plate)
+  if dist.batch_shape != []:  # pylint: disable=g-explicit-bool-comparison
+    raise ValueError(
+        f'Cannot use a distribution with `batch_shape`: {dist.batch_shape}. '
+        'Instead, use `jax.vmap` or `ppl.plate` to draw independent samples.')
   def wrapped(key):
     def sample(key):
       result = primitive.initial_style_bind(
           random_variable_p,
+          batch_ndims=0,
           distribution_name=dist.__class__.__name__)(_sample_distribution)(
               key, dist)
       return result
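Below is a minimal usage sketch of what this change enables, not part of the diff itself. It assumes oryx's public `ppl.random_variable` and `ppl.log_prob` API plus the TFP-on-JAX substrate; that `log_prob` of a vmapped program reduces to a scalar is inferred from the new `batch_ndims` summing in `random_variable_log_prob_rule`, so treat it as illustrative rather than authoritative.

import jax
from oryx.core import ppl
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# Distributions with a non-trivial batch_shape now raise a ValueError in
# distribution_random_variable; scalar distributions remain fine.
sample = ppl.random_variable(tfd.Normal(0., 1.))

# The new batching rule lets the sampling program be vmapped directly: it
# stages the vmapped sampler into a fresh jaxpr and rebinds the primitive
# with batch_ndims incremented by one.
keys = jax.random.split(jax.random.PRNGKey(0), 5)
draws = jax.vmap(sample)(keys)  # five independent draws, shape (5,)

# With batch_ndims tracked through vmap, log_prob sums over the mapped
# axis and returns a single scalar log density (assumed behavior).
total_log_prob = ppl.log_prob(jax.vmap(sample))(draws)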