
Commit 5faa1ac

SiegeLordEx authored and tensorflower-gardener committed
Add support for tf.Variables in TF custom gradients.
PiperOrigin-RevId: 388328639
1 parent 9b23aef commit 5faa1ac
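
For context (not part of this commit's diff): stock tf.custom_gradient passes a "variables" keyword argument to the backward function only when the forward pass actually reads tf.Variables, and the backward function must then return a pair of (input gradients, variable gradients). A minimal sketch of that stock-TensorFlow convention, which the wrapper change below threads through to TFP's vjp_bwd:

import tensorflow as tf

v = tf.Variable(3.)

@tf.custom_gradient
def f(x):
  y = x**2 + v**2

  # TF supplies `variables` here only because `v` is read above; the backward
  # function must then return (grads w.r.t. inputs, grads w.r.t. variables).
  def grad_fn(dy, variables=None):
    return 2. * dy * x, [2. * dy * variables[0]]

  return y, grad_fn

x = tf.constant(2.)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = f(x)
dx, dv = tape.gradient(y, (x, v))  # dx == 4.0, dv == 6.0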

File tree

3 files changed: +182 -4 lines changed

tensorflow_probability/python/internal/BUILD

Lines changed: 11 additions & 0 deletions
@@ -233,6 +233,17 @@ multi_substrate_py_library(
     ],
 )
 
+multi_substrate_py_test(
+    name = "custom_gradient_test",
+    srcs = ["custom_gradient_test.py"],
+    deps = [
+        ":custom_gradient",
+        # tensorflow dep,
+        "//tensorflow_probability/python/internal:test_util",
+        "//tensorflow_probability/python/math:gradient",
+    ],
+)
+
 py_test(
     name = "cache_util_test",
     size = "small",

tensorflow_probability/python/internal/custom_gradient.py

Lines changed: 15 additions & 4 deletions
@@ -92,14 +92,25 @@ def f_wrapped(*args, **kwargs):
         args = args[1:]
       val, aux = vjp_fwd(*reconstruct_args, **kwargs)
 
-      def vjp_bwd_wrapped(*g):
+      def vjp_bwd_wrapped(*g, **kwargs):
+        # We don't want to use an explicit `variables` arg, because TF will
+        # complain if the wrapped function doesn't actually have variables
+        # in it. TF will only specify this arg if there are variables.
+        variables = kwargs.get('variables', ())
         nondiff_args = [closure[i] for i in nondiff_argnums]
-        result = tf.nest.flatten(
-            vjp_bwd(*nondiff_args, aux, tf.nest.pack_sequence_as(val, g)))
+        result = vjp_bwd(*nondiff_args, aux,
+                         tf.nest.pack_sequence_as(val, g), **kwargs)
+        if variables:
+          result, variables = result
+        result = tf.nest.flatten(result)
         for i in nondiff_argnums:
          result = tuple(result[:i]) + (None,) + tuple(result[i:])
         result = [a for i, a in enumerate(result) if i not in closure]
-        return tf.nest.pack_sequence_as(args_structure, result)
+        result = tf.nest.pack_sequence_as(args_structure, result)
+        if variables:
+          return result, variables
+        else:
+          return result
 
       return val, vjp_bwd_wrapped

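With this change, a vjp_bwd passed to TFP's custom_gradient may take the same optional variables argument and return an extra list of variable gradients; the new test file below exercises this end to end. A condensed sketch of the resulting user-facing contract (mirroring testVJPWithVariables):

import tensorflow.compat.v2 as tf
from tensorflow_probability.python.internal import custom_gradient

y = tf.Variable(3.)

def f_vjp_fwd(x):
  return x**2 + y**2, x  # (output, auxiliary data saved for the backward pass)

def f_vjp_bwd(x, dz, variables):
  # `variables` is only passed when the forward pass read tf.Variables;
  # return (input gradients, [variable gradients]) in that case.
  return 7. * dz * x, [7. * dz * variables[0]]

@custom_gradient.custom_gradient(vjp_fwd=f_vjp_fwd, vjp_bwd=f_vjp_bwd)
def f(x):
  return f_vjp_fwd(x)[0]
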
tensorflow_probability/python/internal/custom_gradient_test.py

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
+# Copyright 2021 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for custom_gradient."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.compat.v2 as tf
+
+from tensorflow_probability.python.internal import custom_gradient
+from tensorflow_probability.python.internal import test_util
+from tensorflow_probability.python.math import gradient as tfp_gradient
+
+JAX_MODE = False
+
+
+@test_util.numpy_disable_gradient_test
+@test_util.test_all_tf_execution_regimes
+class CustomGradientTest(test_util.TestCase):
+
+  def testVJP(self):
+
+    def f_vjp_fwd(x, y):
+      return x**2 + y**2, (x, y)
+
+    def f_vjp_bwd(x_y, dz):
+      x, y = x_y
+      return 7. * dz * x, 7. * dz * y
+
+    @custom_gradient.custom_gradient(
+        vjp_fwd=f_vjp_fwd,
+        vjp_bwd=f_vjp_bwd,
+    )
+    def f(x, y):
+      return f_vjp_fwd(x, y)[0]
+
+    x = tf.constant(2.)
+    y = tf.constant(3.)
+    dz = tf.constant(5.)
+
+    z1 = f(x, y)
+    z2, (dx, dy) = tfp_gradient.value_and_gradient(
+        f, (x, y), output_gradients=dz)
+
+    self.assertAllClose(x**2 + y**2, z1)
+    self.assertAllClose(x**2 + y**2, z2)
+    self.assertAllClose(7. * dz * x, dx)
+    self.assertAllClose(7. * dz * y, dy)
+
+  @test_util.jax_disable_variable_test
+  def testVJPWithVariables(self):
+
+    def f_vjp_fwd(x):
+      return x**2 + y**2, x
+
+    def f_vjp_bwd(x, dz, variables):
+      y = variables[0]
+      return 7. * dz * x, [7. * dz * y]
+
+    @custom_gradient.custom_gradient(
+        vjp_fwd=f_vjp_fwd,
+        vjp_bwd=f_vjp_bwd,
+    )
+    def f(x):
+      return f_vjp_fwd(x)[0]
+
+    x = tf.constant(2.)
+    y = tf.Variable(3.)
+    dz = tf.constant(5.)
+
+    self.evaluate(y.initializer)
+
+    z1 = f(x)
+
+    # Use GradientTape to implicitly capture the variable.
+    with tf.GradientTape() as tape:
+      tape.watch(x)
+      z2 = f(x)
+
+    dx, dy = tape.gradient(z2, (x, y), output_gradients=dz)
+
+    self.assertAllClose(x**2 + y**2, z1)
+    self.assertAllClose(x**2 + y**2, z2)
+    self.assertAllClose(7. * dz * x, dx)
+    self.assertAllClose(7. * dz * y, dy)
+
+  def testJVP(self):
+    if not JAX_MODE:
+      self.skipTest('Custom JVPs are JAX-only.')
+
+    def f_vjp_fwd(x, y):
+      # When a JVP is specified, this function is ignored.
+      raise NotImplementedError()
+
+    def f_vjp_bwd(x_y, dz):
+      # When a JVP is specified, this function is ignored.
+      raise NotImplementedError()
+
+    def f_jvp(x_y, dx_dy):
+      x, y = x_y
+      dx, dy = dx_dy
+      return f(x, y), 7. * (dx * x + dy * y)
+
+    @custom_gradient.custom_gradient(
+        vjp_fwd=f_vjp_fwd,
+        vjp_bwd=f_vjp_bwd,
+        jvp_fn=f_jvp,
+    )
+    def f(x, y):
+      return x**2 + y**2
+
+    x = tf.constant(2.)
+    y = tf.constant(3.)
+    dz = tf.constant(5.)
+
+    z1 = f(x, y)
+    z2, (dx, dy) = tfp_gradient.value_and_gradient(
+        f, (x, y), output_gradients=dz)
+
+    self.assertAllClose(x**2 + y**2, z1)
+    self.assertAllClose(x**2 + y**2, z2)
+    self.assertAllClose(7. * dz * x, dx)
+    self.assertAllClose(7. * dz * y, dy)
+
+    import jax  # pylint: disable=g-import-not-at-top
+
+    z3, dz2 = jax.jvp(f, (x, y), (dx, dy))
+    self.assertAllClose(x**2 + y**2, z3)
+    self.assertAllClose(7. * (dx * x + dy * y), dz2)
+
+  def testPreventGradient(self):
+
+    def f(x):
+      return custom_gradient.prevent_gradient(x, 'No gradient')
+
+    _ = f(1.)
+
+    with self.assertRaisesRegex(LookupError, 'No gradient'):
+      tfp_gradient.value_and_gradient(f, (1.))
+
+
+if __name__ == '__main__':
+  tf.test.main()
