@@ -34,6 +34,7 @@ def value_and_gradient(f,
                        output_gradients=None,
                        use_gradient_tape=False,
                        auto_unpack_single_arg=True,
+                       has_aux=False,
                        name=None,
                        **kwargs):
   """Computes `f(*args, **kwargs)` and its gradients wrt to `args`, `kwargs`.
@@ -92,13 +93,20 @@ def value_and_gradient(f,
     auto_unpack_single_arg: Python `bool` which when `False` means the single
       arg case will not be interpreted as a list of arguments. (See case 2.)
       Default value: `True`.
+    has_aux: Whether `f(*args, **kwargs)` actually returns two outputs, the
+      first being `y` and the second being an auxiliary output that does not get
+      gradients computed.
     name: Python `str` name prefixed to ops created by this function.
       Default value: `None` (i.e., `'value_and_gradient'`).
     **kwargs: Named arguments as in `f(*args, **kwargs)` and basis for gradient.

   Returns:
-    y: `y = f(*args, **kwargs)`.
-    dydx: Gradients of `y` with respect to each of `args` and `kwargs`.
+    If `has_aux` is `False`:
+      y: `y = f(*args, **kwargs)`.
+      dydx: Gradients of `y` with respect to each of `args` and `kwargs`.
+    otherwise:
+      A tuple `((y, aux), dydx)`, where `y, aux = f(*args, **kwargs)` and `dydx`
+      are the gradients of `y` with respect to each of `args` and `kwargs`.
   """
   with tf.name_scope(name or 'value_and_gradient'):
     return _value_and_grad_impl(
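
For reference, a minimal usage sketch of the new `has_aux` flag (assuming the change above and the public `tfp.math.value_and_gradient` export; `loss_and_metric` is illustrative, not part of this diff):

    import tensorflow as tf
    import tensorflow_probability as tfp

    def loss_and_metric(x):
      # The second output is auxiliary: it is returned as-is and never
      # differentiated.
      return tf.reduce_sum(x ** 2), tf.reduce_mean(x)

    x = tf.constant([1., 2., 3.])
    (loss, metric), dloss_dx = tfp.math.value_and_gradient(
        loss_and_metric, x, has_aux=True)
    # loss == 14., metric == 2., dloss_dx == 2. * x
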
@@ -109,6 +117,7 @@ def value_and_gradient(f,
         output_gradients=output_gradients,
         auto_unpack_single_arg=auto_unpack_single_arg,
         expand_tf_modules_as_trainable_vars=False,
+        has_aux=has_aux,
         **kwargs)


@@ -117,6 +126,7 @@ def value_and_gradient_with_auto_expansion(f,
                                            output_gradients=None,
                                            use_gradient_tape=False,
                                            auto_unpack_single_arg=True,
+                                           has_aux=False,
                                            name=None,
                                            **kwargs):
   """Computes `f(*args, **kwargs)` and its gradients wrt to `args`, `kwargs`.
@@ -190,13 +200,20 @@ def value_and_gradient_with_auto_expansion(f,
     auto_unpack_single_arg: Python `bool` which when `False` means the single
       arg case will not be interpreted as a list of arguments. (See case 2.)
       Default value: `True`.
+    has_aux: Whether `f(*args, **kwargs)` actually returns two outputs, the
+      first being `y` and the second being an auxiliary output that does not get
+      gradients computed.
     name: Python `str` name prefixed to ops created by this function.
       Default value: `None` (i.e., `'value_and_gradient'`).
     **kwargs: Named arguments as in `f(*args, **kwargs)` and basis for gradient.

   Returns:
-    y: `y = f(*args, **kwargs)`.
-    dydx: Gradients of `y` with respect to each of `args` and `kwargs`.
+    If `has_aux` is `False`:
+      y: `y = f(*args, **kwargs)`.
+      dydx: Gradients of `y` with respect to each of `args` and `kwargs`.
+    otherwise:
+      A tuple `((y, aux), dydx)`, where `y, aux = f(*args, **kwargs)` and `dydx`
+      are the gradients of `y` with respect to each of `args` and `kwargs`.
   """
   with tf.name_scope(name or 'value_and_gradient'):
     return _value_and_grad_impl(
@@ -207,12 +224,14 @@ def value_and_gradient_with_auto_expansion(f,
         output_gradients=output_gradients,
         auto_unpack_single_arg=auto_unpack_single_arg,
         expand_tf_modules_as_trainable_vars=True,
+        has_aux=has_aux,
         **kwargs)


 def value_and_batch_jacobian(f,
                              *args,
                              auto_unpack_single_arg=True,
+                             has_aux=False,
                              name=None,
                              **kwargs):
   """Computes `f(*args, **kwargs)` and batch Jacobian wrt to `args`, `kwargs`.
@@ -225,15 +244,23 @@ def value_and_batch_jacobian(f,
     auto_unpack_single_arg: Python `bool` which when `False` means the single
       arg case will not be interpreted as a list of arguments.
       Default value: `True`.
+    has_aux: Whether `f(*args, **kwargs)` actually returns two outputs, the
+      first being `y` and the second being an auxiliary output that does not get
+      gradients computed.
     name: Python `str` name prefixed to ops created by this function.
       Default value: `None` (i.e., `'value_and_gradient'`).
     **kwargs: Named arguments as in `f(*args, **kwargs)` and basis for Jacobian.
       Each element must be 2D `(batch, n)`-shaped argument `Tensor`(s). If
       multiple are provided, a tuple of jacobians are returned.

   Returns:
-    y: `y = f(*args, **kwargs)`.
-    jacobian: A `(batch, n, n)` shaped `Tensor`, `dy/dx`, or a tuple thereof.
+    If `has_aux` is `False`:
+      y: `y = f(*args, **kwargs)`.
+      jacobian: A `(batch, n, n)` shaped `Tensor`, `dy/dx`, or a tuple thereof.
+    otherwise:
+      A tuple `((y, aux), jacobian)`, where `y, aux = f(*args, **kwargs)` and
+      `jacobian` is a `(batch, n, n)` shaped `Tensor`, `dy/dx`, or a tuple
+      thereof.
   """
   with tf.name_scope(name or 'value_and_batch_jacobian'):
     return _value_and_grad_impl(
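
For reference, a sketch of the batch-Jacobian variant with `has_aux=True` (assumes the change above and that this helper is exposed as `tfp.math.value_and_batch_jacobian`; `f` is an illustrative `(batch, n) -> (batch, n)` function):

    import tensorflow as tf
    import tensorflow_probability as tfp

    def f(x):
      # y keeps the (batch, n) shape of x; aux is a per-batch summary that is
      # not differentiated.
      return tf.math.tanh(x), tf.reduce_max(x, axis=-1)

    x = tf.random.normal([4, 3])
    (y, aux), jac = tfp.math.value_and_batch_jacobian(f, x, has_aux=True)
    # y: shape [4, 3], aux: shape [4], jac: shape [4, 3, 3] (dy/dx per batch)
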
@@ -243,12 +270,14 @@ def value_and_batch_jacobian(f,
         output_gradients=None,
         auto_unpack_single_arg=auto_unpack_single_arg,
         expand_tf_modules_as_trainable_vars=False,
+        has_aux=has_aux,
         **kwargs)


 def batch_jacobian(f,
                    *args,
                    auto_unpack_single_arg=True,
+                   has_aux=False,
                    name=None,
                    **kwargs):
   """Computes batch Jacobian of `f(*args, **kwargs)` wrt to `args`, `kwargs`.
@@ -261,53 +290,68 @@ def batch_jacobian(f,
     auto_unpack_single_arg: Python `bool` which when `False` means the single
       arg case will not be interpreted as a list of arguments.
       Default value: `True`.
+    has_aux: Whether `f(*args, **kwargs)` actually returns two outputs, the
+      first being `y` and the second being an auxiliary output that does not get
+      gradients computed.
     name: Python `str` name prefixed to ops created by this function.
       Default value: `None` (i.e., `'value_and_gradient'`).
     **kwargs: Named arguments as in `f(*args, **kwargs)` and basis for Jacobian.
       Each element must be 2D `(batch, n)`-shaped argument `Tensor`(s). If
       multiple are provided, a tuple of jacobians are returned.

   Returns:
-    jacobian: A `(batch, n, n)` shaped `Tensor`, `dy/dx`, or a tuple thereof.
+    If `has_aux` is `False`:
+      jacobian: A `(batch, n, n)` shaped `Tensor`, `dy/dx`, or a tuple thereof.
+    otherwise:
+      jacobian: A `(batch, n, n)` shaped `Tensor`, `dy/dx`, or a tuple thereof.
+      aux: The auxiliary output of the function `y, aux = f(*args, **kwargs)`.
   """
-  return value_and_batch_jacobian(
+  res = value_and_batch_jacobian(
       f,
       *args,
       auto_unpack_single_arg=auto_unpack_single_arg,
       name=name,
-      **kwargs)[1]
+      has_aux=has_aux,
+      **kwargs)
+  if has_aux:
+    (_, aux), jacobian = res
+    return jacobian, aux
+  else:
+    _, jacobian = res
+    return jacobian


 def _gradient_new(f, xs, grad_ys):
   with tf.GradientTape(watch_accessed_variables=False) as tape:
     for x in xs:
       tape.watch(x)
-    y = f()
-    return y, tape.gradient(y, xs, output_gradients=grad_ys)
+    y, aux = f()
+    return y, tape.gradient(y, xs, output_gradients=grad_ys), aux


 def _gradient_old(f, xs, grad_ys):
   assert not tf.executing_eagerly()
-  y = f()
-  return y, tf.gradients(y, xs, grad_ys=grad_ys)
+  y, aux = f()
+  return y, tf.gradients(y, xs, grad_ys=grad_ys), aux


 def _jacobian(f, xs, grad_ys):
   assert grad_ys is None
   with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
     for x in xs:
       tape.watch(x)
-    y = f()
+    y, aux = f()
   try:
-    return y, tuple(tape.batch_jacobian(y, x) for x in xs)
+    return y, tuple(tape.batch_jacobian(y, x) for x in xs), aux
   except ValueError:  # Fallback to for-loop jacobian.
     return y, tuple(tape.batch_jacobian(y, x, experimental_use_pfor=False)
-                    for x in xs)
+                    for x in xs), aux


 def _value_and_grad_impl(f, grad_fn, *args, output_gradients,
                          auto_unpack_single_arg,
                          expand_tf_modules_as_trainable_vars=False,
+                         has_aux=False,
                          **kwargs):
   """Helper which generalizes gradient / Jacobian."""
   if not args and not kwargs:
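
For reference, a sketch of the new `batch_jacobian` return convention: the value `y` is dropped, and with `has_aux=True` the pair `(jacobian, aux)` comes back (assumes the change above and a public `tfp.math.batch_jacobian` export; `f` is illustrative):

    import tensorflow as tf
    import tensorflow_probability as tfp

    def f(x):
      return tf.math.sin(x), tf.reduce_sum(x)  # (y, aux)

    x = tf.ones([2, 5])
    jac, aux = tfp.math.batch_jacobian(f, x, has_aux=True)
    # jac: shape [2, 5, 5] with cos(x) on each batch diagonal; aux == 10.
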
@@ -329,18 +373,31 @@ def _value_and_grad_impl(f, grad_fn, *args, output_gradients,
         [args, kwargs])
   else:
     expand_args, expand_kwargs = args, kwargs
-  y, dydx = grad_fn(lambda: f(*args, **kwargs) if _has_args(f) else f(),
-                    tf.nest.flatten([expand_args, expand_kwargs]),
-                    output_gradients)
+
+  if not has_aux:
+    real_f = f
+    f = lambda *args, **kwargs: (real_f(*args, **kwargs)  # pylint: disable=g-long-lambda
+                                 if _has_args(real_f) else real_f(), ())
+
+  y, dydx, aux = grad_fn(lambda: f(*args, **kwargs) if _has_args(f) else f(),
+                         tf.nest.flatten([expand_args, expand_kwargs]),
+                         output_gradients)
   dydx_args, dydx_kwargs = tf.nest.pack_sequence_as(
       [expand_args, expand_kwargs], dydx)
   if len(args) == 1 and not do_unpack:
     dydx_args = dydx_args[0]
-  if not kwargs:
-    return y, dydx_args
-  if not args:
-    return y, dydx_kwargs
-  return y, dydx_args, dydx_kwargs
+
+  if has_aux:
+    res = ((y, aux),)
+  else:
+    res = (y,)
+
+  if args:
+    res += (dydx_args,)
+  if kwargs:
+    res += (dydx_kwargs,)
+
+  return res


 def _prepare_args(args, kwargs):
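
The wrapping trick above can be read in isolation: when the caller's `f` has no auxiliary output, it is wrapped so the gradient helpers can always unpack a `(y, aux)` pair, with an empty tuple standing in for `aux`. A standalone sketch (the helper name is hypothetical, not part of this diff):

    def _ensure_aux_pair(f, has_aux):
      # Hypothetical mirror of the inline lambda in _value_and_grad_impl.
      if has_aux:
        return f
      return lambda *a, **kw: (f(*a, **kw), ())

    wrapped = _ensure_aux_pair(lambda x: x * x, has_aux=False)
    y, aux = wrapped(3.)  # y == 9., aux == ()
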
@@ -380,8 +437,9 @@ def value_and_gradient(f,  # pylint: disable=function-redefined
                        *args,
                        output_gradients=None,
                        use_gradient_tape=False,  # pylint: disable=unused-argument
-                       name=None,  # pylint: disable=unused-argument
                        auto_unpack_single_arg=True,
+                       has_aux=False,
+                       name=None,  # pylint: disable=unused-argument
                        **kwargs):
   """Computes `f(*args)` and its gradients wrt to `*args`."""
   if kwargs:
@@ -392,16 +450,27 @@ def value_and_gradient(f,  # pylint: disable=function-redefined
   if do_unpack:
     args = args[0]
   args, _ = _prepare_args(args, {})
-  y, f_vjp = jax.vjp(f, *args)
+  if has_aux:
+    y, f_vjp, aux = jax.vjp(f, *args, has_aux=True)
+  else:
+    y, f_vjp = jax.vjp(f, *args)
   if output_gradients is None:
     output_gradients = tf.nest.map_structure(np.ones_like, y)
   dydx = list(f_vjp(output_gradients))
   if len(args) == 1 and not do_unpack:
     dydx = dydx[0]
-  return y, dydx
+  if has_aux:
+    return (y, aux), dydx
+  else:
+    return y, dydx

 def value_and_batch_jacobian(  # pylint: disable=function-redefined
-    f, *args, auto_unpack_single_arg=True, name=None, **kwargs):  # pylint: disable=unused-argument
+    f,
+    *args,
+    auto_unpack_single_arg=True,
+    has_aux=False,
+    name=None,  # pylint: disable=unused-argument
+    **kwargs):
   """JAX implementation of value_and_batch_jacobian."""
   if kwargs:
     raise NotImplementedError('Jax version of `value_and_batch_jacobian` '
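
The JAX path above leans on `jax.vjp(..., has_aux=True)`, which returns the primal output, the VJP function, and the auxiliary output. A small sketch of that standard JAX behavior (illustrative function):

    import jax
    import jax.numpy as jnp

    def f(x):
      return jnp.sum(x ** 2), jnp.max(x)  # (y, aux)

    x = jnp.array([1., 2., 3.])
    y, f_vjp, aux = jax.vjp(f, x, has_aux=True)
    (dydx,) = f_vjp(jnp.ones_like(y))  # dydx == 2. * x
    # y == 14., aux == 3.
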
@@ -411,7 +480,10 @@ def value_and_batch_jacobian(  # pylint: disable=function-redefined
   if do_unpack:
     args = args[0]
   args, _ = _prepare_args(args, {})
-  y, f_vjp = jax.vjp(f, *args)
+  if has_aux:
+    y, f_vjp, aux = jax.vjp(f, *args, has_aux=True)
+  else:
+    y, f_vjp = jax.vjp(f, *args)

   # Let `[B, E_1, ..., E_k]` be the shape of `y`, where the first dimension
   # is a batch dimension. We construct a basis for the cotangent space
@@ -426,13 +498,28 @@ def value_and_batch_jacobian(  # pylint: disable=function-redefined
   dydx = [x.reshape(y.shape + x.shape[2:]) for x in dydx]
   if len(args) == 1 and not do_unpack:
     dydx = dydx[0]
-  return y, dydx
+  if has_aux:
+    return (y, aux), dydx
+  else:
+    return y, dydx

 def batch_jacobian(  # pylint: disable=function-redefined
-    f, *args, auto_unpack_single_arg=True, name=None, **kwargs):  # pylint: disable=unused-argument
+    f,
+    *args,
+    auto_unpack_single_arg=True,
+    has_aux=False,
+    name=None,
+    **kwargs):  # pylint: disable=unused-argument
   """Computes the batch jacobian of `f(xs)` w.r.t. `xs`."""
-  return value_and_batch_jacobian(
+  res = value_and_batch_jacobian(
       f,
       *args,
       auto_unpack_single_arg=auto_unpack_single_arg,
-      **kwargs)[1]
+      has_aux=has_aux,
+      **kwargs)
+  if has_aux:
+    (_, aux), jacobian = res
+    return jacobian, aux
+  else:
+    _, jacobian = res
+    return jacobian