18
18
import jax
19
19
from jax import abstract_arrays
20
20
from jax import core as jax_core
21
- from jax import linear_util as lu
22
21
from jax import tree_util
23
22
from jax import util as jax_util
24
- from jax .interpreters import partial_eval as pe
25
23
from jax .interpreters import pxla
26
- from jax .interpreters import xla
27
24
import jax .numpy as np
28
25
29
26
from oryx .core import primitive
30
27
from oryx .core import trace_util
31
- from oryx .core .interpreters import harvest
32
28
from oryx .core .interpreters import propagate
33
29
from oryx .core .interpreters .inverse import slice as slc
34
30
@@ -177,8 +173,8 @@ def wrapped(*args, **kwargs):
177
173
for arg in flat_forward_args ]
178
174
flat_incells = [InverseAndILDJ .unknown (aval ) for aval in flat_forward_avals ]
179
175
flat_outcells = safe_map (InverseAndILDJ .new , flat_args )
180
- env = propagate .propagate (InverseAndILDJ , ildj_registry , jaxpr .jaxpr ,
181
- flat_constcells , flat_incells , flat_outcells ) # pytype: disable=wrong-arg-types
176
+ env , _ = propagate .propagate (InverseAndILDJ , ildj_registry , jaxpr .jaxpr ,
177
+ flat_constcells , flat_incells , flat_outcells ) # pytype: disable=wrong-arg-types
182
178
flat_incells = [env .read (invar ) for invar in jaxpr .jaxpr .invars ]
183
179
if any (not flat_incell .top () for flat_incell in flat_incells ):
184
180
raise ValueError ('Cannot invert function.' )
@@ -246,6 +242,9 @@ def __getitem__(self, prim):
246
242
def __setitem__ (self , prim , val ):
247
243
self .rules [prim ] = val
248
244
245
+ def __contains__ (self , prim ):
246
+ return prim in self .rules
247
+
249
248
250
249
def register_elementwise (prim ):
251
250
"""Registers an elementwise primitive with ILDJ."""
@@ -296,46 +295,19 @@ def ildj_rule(incells, outcells, **params):
296
295
ildj_registry = InverseDict ()
297
296
298
297
299
- @lu .transformation_with_aux
300
- def flat_propagate (tree , * flat_invals ):
301
- invals , outvals = tree_util .tree_unflatten (tree , flat_invals )
302
- subenv = yield ((invals , outvals ), {})
303
- subenv_vals , subenv_tree = tree_util .tree_flatten (subenv )
304
- yield subenv_vals , subenv_tree
305
-
306
-
307
- def call_ildj (prim , incells , outcells , ** params ):
308
- """InverseAndILDJ rule for call primitives."""
309
- f , incells = incells [0 ], incells [1 :]
310
- flat_vals , in_tree = tree_util .tree_flatten ((incells , outcells ))
311
- new_params = dict (params )
312
- if 'donated_invars' in params :
313
- new_params ['donated_invars' ] = (False ,) * len (flat_vals )
314
- f , aux = flat_propagate (f , in_tree )
315
- subenv_vals = prim .bind (f , * flat_vals , ** new_params )
316
- subenv_tree = aux ()
317
- subenv = tree_util .tree_unflatten (subenv_tree , subenv_vals )
318
- new_incells = [subenv .read (var ) for var in subenv .jaxpr .invars ]
319
- new_outcells = [subenv .read (var ) for var in subenv .jaxpr .outvars ]
320
- return new_incells , new_outcells , subenv
321
- ildj_registry [xla .xla_call_p ] = jax_util .partial (call_ildj , xla .xla_call_p )
322
- ildj_registry [jax_core .call_p ] = jax_util .partial (call_ildj , jax_core .call_p )
323
- ildj_registry [pe .remat_call_p ] = jax_util .partial (call_ildj , pe .remat_call_p )
324
- ildj_registry [harvest .nest_p ] = jax_util .partial (call_ildj , harvest .nest_p )
325
-
326
-
327
298
def hop_inverse_rule (prim ):
328
- ildj_registry [prim ] = jax_util .partial (call_ildj , prim )
299
+ ildj_registry [prim ] = jax_util .partial (propagate . call_rule , prim )
329
300
primitive .register_hop_transformation_rule ('inverse' , hop_inverse_rule )
330
301
331
302
332
303
def initial_ildj (incells , outcells , * , jaxpr , num_consts , ** _ ):
333
304
const_cells , incells = jax_util .split_list (incells , [num_consts ])
334
- env = propagate .propagate (InverseAndILDJ , ildj_registry , jaxpr , const_cells ,
335
- incells , outcells ) # pytype: disable=wrong-arg-types
305
+ env , state = propagate .propagate (
306
+ InverseAndILDJ , ildj_registry , jaxpr , const_cells ,
307
+ incells , outcells ) # pytype: disable=wrong-arg-types
336
308
new_incells = [env .read (invar ) for invar in jaxpr .invars ]
337
309
new_outcells = [env .read (outvar ) for outvar in jaxpr .outvars ]
338
- return const_cells + new_incells , new_outcells , None
310
+ return const_cells + new_incells , new_outcells , state
339
311
340
312
341
313
def initial_inverse_rule (prim ):
@@ -371,7 +343,7 @@ def remove_slice(cell):
371
343
mapped_incells = safe_map (remove_slice , incells )
372
344
mapped_outcells = safe_map (remove_slice , outcells )
373
345
flat_vals , in_tree = tree_util .tree_flatten ((mapped_incells , mapped_outcells ))
374
- f , aux = flat_propagate (f , in_tree )
346
+ f , aux = propagate . flat_propagate (f , in_tree )
375
347
# Assume all invars as mapped
376
348
new_in_axes = (0 ,) * len (flat_vals )
377
349
new_params = dict (params , in_axes = new_in_axes )
@@ -383,14 +355,13 @@ def remove_slice(cell):
383
355
lambda : (0 ,) * aux ().num_leaves ,
384
356
closure = ('ildj' , params ['out_axes' ]))
385
357
del new_params ['out_axes' ]
386
- subenv_vals = prim .bind (f , * flat_vals , ** new_params )
387
- subenv_tree = aux ()
388
- subenv = tree_util .tree_unflatten (subenv_tree , subenv_vals )
389
- new_incells = [subenv .read (var ) for var in subenv .jaxpr .invars ]
390
- new_outcells = [subenv .read (var ) for var in subenv .jaxpr .outvars ]
358
+ flat_out = prim .bind (f , * flat_vals , ** new_params )
359
+ out_tree = aux ()
360
+ new_incells , new_outcells , state = tree_util .tree_unflatten (
361
+ out_tree , flat_out )
391
362
new_incells = [add_slice (v , old_v )
392
363
for old_v , v in safe_zip (incells , new_incells )]
393
364
new_outcells = [add_slice (v , old_v )
394
365
for old_v , v in safe_zip (outcells , new_outcells )]
395
- return new_incells , new_outcells , subenv
366
+ return new_incells , new_outcells , state
396
367
ildj_registry [pxla .xla_pmap_p ] = jax_util .partial (map_ildj , pxla .xla_pmap_p )
0 commit comments