Commit a100002

SiegeLordEx authored and tensorflower-gardener committed
Allow returning auxiliary information from tfp.math.value_and_gradient.
Previously this had to be done by returning tensors wrapped in stop_gradient, but the new approach is more convenient and more powerful: it still lets you compute gradients that flow through the auxiliary return value, which stop_gradient would prevent.

PiperOrigin-RevId: 388329574
1 parent 5faa1ac commit a100002
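For orientation, here is a minimal sketch of the new `has_aux` usage, mirroring the `test_aux` and `test_aux_grads` cases added in gradient_test.py below; it assumes the usual `tensorflow` and `tensorflow_probability` imports.

import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant([[2.]])

def f(x):
  # The second output is auxiliary: it is returned alongside `y`, but
  # value_and_gradient does not compute a gradient for it.
  return x**2, x**2

(y, aux), dx = tfp.math.value_and_gradient(f, x, has_aux=True)
# y == x**2, dx == 2 * x, aux == x**2

# Unlike wrapping the extra output in tf.stop_gradient, gradients can still
# flow through `aux` when it is used downstream:
def g(x):
  (_, aux), _ = tfp.math.value_and_gradient(f, x, has_aux=True)
  return aux

aux_val, daux = tfp.math.value_and_gradient(g, x)  # daux == 2 * x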

File tree

2 files changed: 177 additions & 33 deletions

tensorflow_probability/python/math/gradient.py

Lines changed: 120 additions & 33 deletions
@@ -34,6 +34,7 @@ def value_and_gradient(f,
                       output_gradients=None,
                       use_gradient_tape=False,
                       auto_unpack_single_arg=True,
+                      has_aux=False,
                       name=None,
                       **kwargs):
  """Computes `f(*args, **kwargs)` and its gradients wrt to `args`, `kwargs`.
@@ -92,13 +93,20 @@ def value_and_gradient(f,
    auto_unpack_single_arg: Python `bool` which when `False` means the single
      arg case will not be interpreted as a list of arguments. (See case 2.)
      Default value: `True`.
+    has_aux: Whether `f(*args, **kwargs)` actually returns two outputs, the
+      first being `y` and the second being an auxiliary output that does not get
+      gradients computed.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., `'value_and_gradient'`).
    **kwargs: Named arguments as in `f(*args, **kwargs)` and basis for gradient.

  Returns:
-    y: `y = f(*args, **kwargs)`.
-    dydx: Gradients of `y` with respect to each of `args` and `kwargs`.
+    If `has_aux` is `False`:
+      y: `y = f(*args, **kwargs)`.
+      dydx: Gradients of `y` with respect to each of `args` and `kwargs`.
+    otherwise:
+      A tuple `((y, aux), dydx)`, where `y, aux = f(*args, **kwargs)` and `dydx`
+      are the gradients of `y` with respect to each of `args` and `kwargs`.
  """
  with tf.name_scope(name or 'value_and_gradient'):
    return _value_and_grad_impl(
@@ -109,6 +117,7 @@ def value_and_gradient(f,
        output_gradients=output_gradients,
        auto_unpack_single_arg=auto_unpack_single_arg,
        expand_tf_modules_as_trainable_vars=False,
+        has_aux=has_aux,
        **kwargs)


@@ -117,6 +126,7 @@ def value_and_gradient_with_auto_expansion(f,
                                           output_gradients=None,
                                           use_gradient_tape=False,
                                           auto_unpack_single_arg=True,
+                                           has_aux=False,
                                           name=None,
                                           **kwargs):
  """Computes `f(*args, **kwargs)` and its gradients wrt to `args`, `kwargs`.
@@ -190,13 +200,20 @@ def value_and_gradient_with_auto_expansion(f,
    auto_unpack_single_arg: Python `bool` which when `False` means the single
      arg case will not be interpreted as a list of arguments. (See case 2.)
      Default value: `True`.
+    has_aux: Whether `f(*args, **kwargs)` actually returns two outputs, the
+      first being `y` and the second being an auxiliary output that does not get
+      gradients computed.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., `'value_and_gradient'`).
    **kwargs: Named arguments as in `f(*args, **kwargs)` and basis for gradient.

  Returns:
-    y: `y = f(*args, **kwargs)`.
-    dydx: Gradients of `y` with respect to each of `args` and `kwargs`.
+    If `has_aux` is `False`:
+      y: `y = f(*args, **kwargs)`.
+      dydx: Gradients of `y` with respect to each of `args` and `kwargs`.
+    otherwise:
+      A tuple `((y, aux), dydx)`, where `y, aux = f(*args, **kwargs)` and `dydx`
+      are the gradients of `y` with respect to each of `args` and `kwargs`.
  """
  with tf.name_scope(name or 'value_and_gradient'):
    return _value_and_grad_impl(
@@ -207,12 +224,14 @@ def value_and_gradient_with_auto_expansion(f,
        output_gradients=output_gradients,
        auto_unpack_single_arg=auto_unpack_single_arg,
        expand_tf_modules_as_trainable_vars=True,
+        has_aux=has_aux,
        **kwargs)


def value_and_batch_jacobian(f,
                             *args,
                             auto_unpack_single_arg=True,
+                             has_aux=False,
                             name=None,
                             **kwargs):
  """Computes `f(*args, **kwargs)` and batch Jacobian wrt to `args`, `kwargs`.
@@ -225,15 +244,23 @@ def value_and_batch_jacobian(f,
    auto_unpack_single_arg: Python `bool` which when `False` means the single
      arg case will not be interpreted as a list of arguments.
      Default value: `True`.
+    has_aux: Whether `f(*args, **kwargs)` actually returns two outputs, the
+      first being `y` and the second being an auxiliary output that does not get
+      gradients computed.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., `'value_and_gradient'`).
    **kwargs: Named arguments as in `f(*args, **kwargs)` and basis for Jacobian.
      Each element must be 2D `(batch, n)`-shaped argument `Tensor`(s). If
      multiple are provided, a tuple of jacobians are returned.

  Returns:
-    y: `y = f(*args, **kwargs)`.
-    jacobian: A `(batch, n, n)` shaped `Tensor`, `dy/dx`, or a tuple thereof.
+    If `has_aux` is `False`:
+      y: `y = f(*args, **kwargs)`.
+      jacobian: A `(batch, n, n)` shaped `Tensor`, `dy/dx`, or a tuple thereof.
+    otherwise:
+      A tuple `((y, aux), jacobian)`, where `y, aux = f(*args, **kwargs)` and
+      `jacobian` is a `(batch, n, n)` shaped `Tensor`, `dy/dx`, or a tuple
+      thereof.
  """
  with tf.name_scope(name or 'value_and_batch_jacobian'):
    return _value_and_grad_impl(
@@ -243,12 +270,14 @@ def value_and_batch_jacobian(f,
        output_gradients=None,
        auto_unpack_single_arg=auto_unpack_single_arg,
        expand_tf_modules_as_trainable_vars=False,
+        has_aux=has_aux,
        **kwargs)


def batch_jacobian(f,
                   *args,
                   auto_unpack_single_arg=True,
+                   has_aux=False,
                   name=None,
                   **kwargs):
  """Computes batch Jacobian of `f(*args, **kwargs)` wrt to `args`, `kwargs`.
@@ -261,53 +290,68 @@ def batch_jacobian(f,
    auto_unpack_single_arg: Python `bool` which when `False` means the single
      arg case will not be interpreted as a list of arguments.
      Default value: `True`.
+    has_aux: Whether `f(*args, **kwargs)` actually returns two outputs, the
+      first being `y` and the second being an auxiliary output that does not get
+      gradients computed.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., `'value_and_gradient'`).
    **kwargs: Named arguments as in `f(*args, **kwargs)` and basis for Jacobian.
      Each element must be 2D `(batch, n)`-shaped argument `Tensor`(s). If
      multiple are provided, a tuple of jacobians are returned.

  Returns:
-    jacobian: A `(batch, n, n)` shaped `Tensor`, `dy/dx`, or a tuple thereof.
+    If `has_aux` is `False`:
+      jacobian: A `(batch, n, n)` shaped `Tensor`, `dy/dx`, or a tuple thereof.
+    otherwise:
+      jacobian: A `(batch, n, n)` shaped `Tensor`, `dy/dx`, or a tuple thereof.
+      aux: The auxiliary output of the function `y, aux = f(*args, **kwargs)`.
  """
-  return value_and_batch_jacobian(
+  res = value_and_batch_jacobian(
      f,
      *args,
      auto_unpack_single_arg=auto_unpack_single_arg,
      name=name,
-      **kwargs)[1]
+      has_aux=has_aux,
+      **kwargs)
+  if has_aux:
+    (_, aux), jacobian = res
+    return jacobian, aux
+  else:
+    _, jacobian = res
+    return jacobian


def _gradient_new(f, xs, grad_ys):
  with tf.GradientTape(watch_accessed_variables=False) as tape:
    for x in xs:
      tape.watch(x)
-    y = f()
-  return y, tape.gradient(y, xs, output_gradients=grad_ys)
+    y, aux = f()
+  return y, tape.gradient(y, xs, output_gradients=grad_ys), aux


def _gradient_old(f, xs, grad_ys):
  assert not tf.executing_eagerly()
-  y = f()
-  return y, tf.gradients(y, xs, grad_ys=grad_ys)
+  y, aux = f()
+  return y, tf.gradients(y, xs, grad_ys=grad_ys), aux


def _jacobian(f, xs, grad_ys):
  assert grad_ys is None
  with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
    for x in xs:
      tape.watch(x)
-    y = f()
+    y, aux = f()
  try:
-    return y, tuple(tape.batch_jacobian(y, x) for x in xs)
+    return y, tuple(tape.batch_jacobian(y, x) for x in xs), aux
  except ValueError:  # Fallback to for-loop jacobian.
    return y, tuple(tape.batch_jacobian(y, x, experimental_use_pfor=False)
-                    for x in xs)
+                    for x in xs), aux


def _value_and_grad_impl(f, grad_fn, *args, output_gradients,
                         auto_unpack_single_arg,
                         expand_tf_modules_as_trainable_vars=False,
+                         has_aux=False,
                         **kwargs):
  """Helper which generalizes gradient / Jacobian."""
  if not args and not kwargs:
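As a concrete illustration of the reworked `batch_jacobian` return convention above, here is a small sketch mirroring the `test_aux` case in gradient_test.py below; the import path assumes direct access to the module changed in this commit.

import tensorflow as tf
from tensorflow_probability.python.math.gradient import batch_jacobian
from tensorflow_probability.python.math.gradient import value_and_batch_jacobian

x = tf.constant([[2.]])

def f(x):
  return x**2, x  # (value, auxiliary output)

# With has_aux=True, batch_jacobian now returns a (jacobian, aux) pair...
jac, aux = batch_jacobian(f, x, has_aux=True)  # jac shape: (batch, n, n)

# ...while value_and_batch_jacobian returns ((y, aux), jacobian).
(y, aux), jac = value_and_batch_jacobian(f, x, has_aux=True)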
@@ -329,18 +373,31 @@ def _value_and_grad_impl(f, grad_fn, *args, output_gradients,
        [args, kwargs])
  else:
    expand_args, expand_kwargs = args, kwargs
-  y, dydx = grad_fn(lambda: f(*args, **kwargs) if _has_args(f) else f(),
-                    tf.nest.flatten([expand_args, expand_kwargs]),
-                    output_gradients)
+
+  if not has_aux:
+    real_f = f
+    f = lambda *args, **kwargs: (real_f(*args, **kwargs)  # pylint: disable=g-long-lambda
+                                 if _has_args(real_f) else real_f(), ())
+
+  y, dydx, aux = grad_fn(lambda: f(*args, **kwargs) if _has_args(f) else f(),
+                         tf.nest.flatten([expand_args, expand_kwargs]),
+                         output_gradients)
  dydx_args, dydx_kwargs = tf.nest.pack_sequence_as(
      [expand_args, expand_kwargs], dydx)
  if len(args) == 1 and not do_unpack:
    dydx_args = dydx_args[0]
-  if not kwargs:
-    return y, dydx_args
-  if not args:
-    return y, dydx_kwargs
-  return y, dydx_args, dydx_kwargs
+
+  if has_aux:
+    res = ((y, aux),)
+  else:
+    res = (y,)
+
+  if args:
+    res += (dydx_args,)
+  if kwargs:
+    res += (dydx_kwargs,)
+
+  return res


def _prepare_args(args, kwargs):
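The `_value_and_grad_impl` change above hinges on one normalization: when `has_aux=False`, `f` is wrapped so it always returns a `(y, aux)` pair with an empty `aux`, letting the gradient helpers unconditionally unpack two values. A standalone sketch of that pattern, using a hypothetical `wrap_without_aux` helper that is not part of the library:

def wrap_without_aux(real_f):
  # Make every function look like it returns (y, aux), with an empty aux
  # when the caller's function only returns y.
  return lambda *args, **kwargs: (real_f(*args, **kwargs), ())

def square(x):
  return x**2

f = wrap_without_aux(square)
y, aux = f(3.)  # y == 9.0, aux == ()
# Downstream code can now always write `y, aux = f(...)`, exactly as
# _gradient_new, _gradient_old and _jacobian do after this change.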
@@ -380,8 +437,9 @@ def value_and_gradient(f,  # pylint: disable=function-redefined
                       *args,
                       output_gradients=None,
                       use_gradient_tape=False,  # pylint: disable=unused-argument
-                       name=None,  # pylint: disable=unused-argument
                       auto_unpack_single_arg=True,
+                       has_aux=False,
+                       name=None,  # pylint: disable=unused-argument
                       **kwargs):
  """Computes `f(*args)` and its gradients wrt to `*args`."""
  if kwargs:
@@ -392,16 +450,27 @@ def value_and_gradient(f,  # pylint: disable=function-redefined
  if do_unpack:
    args = args[0]
  args, _ = _prepare_args(args, {})
-  y, f_vjp = jax.vjp(f, *args)
+  if has_aux:
+    y, f_vjp, aux = jax.vjp(f, *args, has_aux=True)
+  else:
+    y, f_vjp = jax.vjp(f, *args)
  if output_gradients is None:
    output_gradients = tf.nest.map_structure(np.ones_like, y)
  dydx = list(f_vjp(output_gradients))
  if len(args) == 1 and not do_unpack:
    dydx = dydx[0]
-  return y, dydx
+  if has_aux:
+    return (y, aux), dydx
+  else:
+    return y, dydx

def value_and_batch_jacobian(  # pylint: disable=function-redefined
-    f, *args, auto_unpack_single_arg=True, name=None, **kwargs):  # pylint: disable=unused-argument
+    f,
+    *args,
+    auto_unpack_single_arg=True,
+    has_aux=False,
+    name=None,  # pylint: disable=unused-argument
+    **kwargs):
  """JAX implementation of value_and_batch_jacobian."""
  if kwargs:
    raise NotImplementedError('Jax version of `value_and_batch_jacobian` '
@@ -411,7 +480,10 @@ def value_and_batch_jacobian(  # pylint: disable=function-redefined
  if do_unpack:
    args = args[0]
  args, _ = _prepare_args(args, {})
-  y, f_vjp = jax.vjp(f, *args)
+  if has_aux:
+    y, f_vjp, aux = jax.vjp(f, *args, has_aux=True)
+  else:
+    y, f_vjp = jax.vjp(f, *args)

  # Let `[B, E_1, ..., E_k]` be the shape of `y`, where the first dimension
  # is a batch dimension. We construct a basis for the cotangent space
@@ -426,13 +498,28 @@ def value_and_batch_jacobian(  # pylint: disable=function-redefined
  dydx = [x.reshape(y.shape + x.shape[2:]) for x in dydx]
  if len(args) == 1 and not do_unpack:
    dydx = dydx[0]
-  return y, dydx
+  if has_aux:
+    return (y, aux), dydx
+  else:
+    return y, dydx

def batch_jacobian(  # pylint: disable=function-redefined
-    f, *args, auto_unpack_single_arg=True, name=None, **kwargs):  # pylint: disable=unused-argument
+    f,
+    *args,
+    auto_unpack_single_arg=True,
+    has_aux=False,
+    name=None,
+    **kwargs):  # pylint: disable=unused-argument
  """Computes the batch jacobian of `f(xs)` w.r.t. `xs`."""
-  return value_and_batch_jacobian(
+  res = value_and_batch_jacobian(
      f,
      *args,
      auto_unpack_single_arg=auto_unpack_single_arg,
-      **kwargs)[1]
+      has_aux=has_aux,
+      **kwargs)
+  if has_aux:
+    (_, aux), jacobian = res
+    return jacobian, aux
+  else:
+    _, jacobian = res
+    return jacobian
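The JAX implementations above lean directly on `jax.vjp`'s native `has_aux` support, which returns `(y, vjp_fn, aux)` instead of `(y, vjp_fn)`. A minimal standalone sketch in plain JAX, independent of TFP:

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sum(x**2), x  # (value, auxiliary output)

x = jnp.array([[2.]])
y, f_vjp, aux = jax.vjp(f, x, has_aux=True)
(dx,) = f_vjp(jnp.ones_like(y))  # a cotangent is supplied for y only
# y == 4.0, dx == 2 * x, aux == x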

tensorflow_probability/python/math/gradient_test.py

Lines changed: 57 additions & 0 deletions
@@ -253,6 +253,63 @@ def f(x, y):  # [4, 2, 3], [4, 2, 1, 3] -> [4, 3, 2]
    self.assertAllClose(grad[0], jac[0][idx])
    self.assertAllClose(grad[1], jac[1][idx])

+  @test_util.numpy_disable_gradient_test
+  def test_aux(self):
+    x = tf.constant([[2.]])
+
+    def f(x):
+      return x**2, x
+
+    (y, aux), dx = tfm.value_and_gradient(f, x, has_aux=True)
+
+    self.assertAllClose(x**2, y)
+    self.assertAllClose(2 * x, dx)
+    self.assertAllClose(x, aux)
+
+    dx, aux = batch_jacobian(f, x, has_aux=True)
+
+    self.assertAllClose((2 * x)[..., tf.newaxis], dx)
+    self.assertAllClose(x, aux)
+
+  @test_util.numpy_disable_gradient_test
+  def test_aux_multi_arg(self):
+    x = tf.constant([[2.]])
+    z = tf.constant([[3.]])
+
+    def f(x, z):
+      return x**2 + z**2, (x, z)
+
+    (y, aux), (dx, dz) = tfm.value_and_gradient(f, (x, z), has_aux=True)
+
+    self.assertAllClose(x**2 + z**2, y)
+    self.assertAllClose(2 * x, dx)
+    self.assertAllClose(2 * z, dz)
+    self.assertAllClose(x, aux[0])
+    self.assertAllClose(z, aux[1])
+
+    (dx, dz), aux = batch_jacobian(f, (x, z), has_aux=True)
+
+    self.assertAllClose((2 * x)[..., tf.newaxis], dx)
+    self.assertAllClose((2 * z)[..., tf.newaxis], dz)
+    self.assertAllClose(x, aux[0])
+    self.assertAllClose(z, aux[1])
+
+  @test_util.numpy_disable_gradient_test
+  def test_aux_grads(self):
+    """Tests that gradients flow through the auxiliary output."""
+    x = tf.constant([[2.]])
+
+    def f(x):
+      return x**2, x**2
+
+    def f2(x):
+      (_, aux), _ = tfm.value_and_gradient(f, x, has_aux=True)
+      return aux
+
+    y, dx = tfm.value_and_gradient(f2, x)
+    self.assertAllClose(x**2, y)
+    self.assertAllClose(2 * x, dx)
+

if __name__ == '__main__':
  tf.test.main()
