@@ -219,6 +219,7 @@ def f(key):
219
219
from oryx .core import primitive
220
220
from oryx .core .interpreters import harvest
221
221
from oryx .core .interpreters import log_prob as lp
222
+ from oryx .core .ppl import plate_util
222
223
223
224
__all__ = [
224
225
'block' ,
@@ -245,7 +246,10 @@ def f(key):
245
246
246
247
247
248
@functools .singledispatch
248
- def random_variable (obj , * , name : Optional [str ] = None ) -> Program :
249
+ def random_variable (obj ,
250
+ * ,
251
+ name : Optional [str ] = None ,
252
+ plate : Optional [str ] = None ) -> Program : # pylint: disable=redefined-outer-name
249
253
"""A single-dispatch function used to tag values and the outputs of programs.
250
254
251
255
`random_variable` is a single-dispatch function that enables registering
@@ -255,16 +259,67 @@ def random_variable(obj, *, name: Optional[str] = None) -> Program:
255
259
Args:
256
260
obj: A JAX type to be tagged.
257
261
name (str): A string name to tag input value, cannot be `None`.
262
+ plate (str): A string named axis for this random variable's plate.
258
263
259
264
Returns:
260
265
The input value.
261
266
"""
262
267
if name is None :
263
268
raise ValueError (f'Cannot call `random_variable` on { type (obj )} '
264
269
'without passing in a name.' )
270
+ if plate is not None :
271
+ raise ValueError (f'Cannot call `random_variable` on { type (obj )} '
272
+ 'with a plate.' )
265
273
return harvest .sow (obj , tag = RANDOM_VARIABLE , name = name , mode = 'strict' )
266
274
267
275
276
+ def plate (f : Optional [Program ] = None , name : Optional [str ] = None ):
277
+ """Transforms a program into one that draws samples on a named axis.
278
+
279
+ In graphical model parlance, a plate designates independent random variables.
280
+ The `plate` transformation follows this idea, where a `plate`-ed program
281
+ draws independent samples. Unlike `jax.vmap`-ing a program, which also
282
+ produces independent samples with positional batch dimensions, `plate`
283
+ produces samples with implicit named axes. Named axis support is useful for
284
+ other JAX transformations like `pmap` and `xmap`.
285
+
286
+ Specifically, a `plate`-ed program creates a different key for each axis
287
+ of the named axis. `log_prob` reduces over the named axis to produce a single
288
+ value.
289
+
290
+ Example usage:
291
+ ```python
292
+ @ppl.plate(name='foo')
293
+ def model(key):
294
+ return random_variable(random.normal)(key)
295
+ # We can't call model directly because there are implicit named axes present
296
+ try:
297
+ model(random.PRNGKey(0))
298
+ except NameError:
299
+ print('No named axis present!')
300
+ # If we vmap with a named axis, we produce independent samples.
301
+ vmap(model, axis_name='foo')(random.split(random.PRNGKey(0), 3)) #
302
+ ```
303
+
304
+ Args:
305
+ f: a `Program` to transform. If `f` is `None`, `plate` returns a decorator.
306
+ name: a `str` name for the plate which can used as a name axis in JAX
307
+ functions and transformations.
308
+
309
+ Returns:
310
+ A decorator if `f` is `None` or a transformed program if `f` is provided.
311
+ The transformed program behaves produces independent across a named
312
+ axis with name `name`.
313
+ """
314
+
315
+ def transform (f : Program ) -> Program :
316
+ return plate_util .make_plate (f , name = name )
317
+
318
+ if f is not None :
319
+ return transform (f )
320
+ return transform
321
+
322
+
268
323
# Alias for random_variable
269
324
rv = random_variable
270
325
@@ -273,21 +328,26 @@ def random_variable(obj, *, name: Optional[str] = None) -> Program:
273
328
@random_variable .register (functools .partial )
274
329
def function_random_variable (f : Program ,
275
330
* ,
276
- name : Optional [str ] = None ) -> Program :
331
+ name : Optional [str ] = None ,
332
+ plate : Optional [str ] = None ) -> Program : # pylint: disable=redefined-outer-name
277
333
"""Registers functions with the `random_variable` single dispatch function.
278
334
279
335
Args:
280
336
f: A probabilistic program.
281
337
name (str): A string name that is used to when tagging the output of `f`.
338
+ plate (str): A string named axis for this random variable's plate.
282
339
283
340
Returns:
284
341
A probabilistic program whose output is tagged with `name`.
285
342
"""
286
343
287
344
def wrapped (* args , ** kwargs ):
345
+ fun = f
346
+ if plate is not None :
347
+ fun = plate_util .make_plate (fun , name = plate )
288
348
if name is not None :
289
- return random_variable (nest (f , scope = name )(* args , ** kwargs ), name = name )
290
- return f (* args , ** kwargs )
349
+ return random_variable (nest (fun , scope = name )(* args , ** kwargs ), name = name )
350
+ return fun (* args , ** kwargs )
291
351
292
352
return wrapped
293
353
0 commit comments