
Commit 19eb16c

Johannes Ballé authored and copybara-github committed
Fixes Keras layer reparameterization in eager mode.
Storing a "computed" layer parameter, such as a reparameterized kernel, as an object attribute in the `build()` function breaks backpropagation in eager mode: if `build()` is called outside of (i.e., before) a `GradientTape` scope, the gradients from the layer output to its variables are disconnected.

One fix is to make `Parameterizer`s return an object that, like a variable, has a `value()` method, and to expose the layer's parameters via `@property`s. Each time a parameter is accessed, the property calls `value()`, recomputing the parameter from its underlying variable. This way, the dependency is tracked by whichever `GradientTape` scope is active when the layer is called.

PiperOrigin-RevId: 292501734
Change-Id: Idec76b516f7799c32e5b0375b09eb76a63e0372f
1 parent 44811b2 commit 19eb16c
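
To illustrate the failure mode and the fix outside of this codebase, here is a minimal, hypothetical sketch (`BrokenLayer`, `FixedLayer`, and the `softplus` reparameterization are stand-ins, not the library's code):

import tensorflow as tf


class BrokenLayer(tf.keras.layers.Layer):
  """Anti-pattern: the reparameterized kernel is computed once in build()."""

  def build(self, input_shape):
    self._var = self.add_weight(name="var", shape=(), initializer="ones")
    # Computed here, outside any later GradientTape scope, so no tape ever
    # records the dependency of `self.kernel` on `self._var`.
    self.kernel = tf.math.softplus(self._var)
    super().build(input_shape)

  def call(self, inputs):
    return inputs * self.kernel


class FixedLayer(tf.keras.layers.Layer):
  """Fix: recompute the parameter on every access via a property."""

  def build(self, input_shape):
    self._var = self.add_weight(name="var", shape=(), initializer="ones")
    super().build(input_shape)

  @property
  def kernel(self):
    # Re-evaluated inside whatever tape scope is active when accessed.
    return tf.math.softplus(self._var)

  def call(self, inputs):
    return inputs * self.kernel


x = tf.ones((2,))
for layer in (BrokenLayer(), FixedLayer()):
  layer.build(x.shape)  # build() runs before (outside) the tape
  with tf.GradientTape() as tape:
    y = layer(x)
  print(tape.gradient(y, layer.trainable_variables))  # [None] vs. a tensor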

6 files changed: +85 −67 lines


tensorflow_compression/python/layers/BUILD

Lines changed: 7 additions & 7 deletions
@@ -34,6 +34,13 @@ py_library(
     deps = [":parameterizers"],
 )
 
+py_test(
+    name = "gdn_test",
+    srcs = ["gdn_test.py"],
+    python_version = "PY3",
+    deps = [":gdn"],
+)
+
 py_library(
     name = "initializers",
     srcs = ["initializers.py"],
@@ -68,13 +75,6 @@ py_test(
     deps = [":entropy_models"],
 )
 
-py_test(
-    name = "gdn_test",
-    srcs = ["gdn_test.py"],
-    python_version = "PY3",
-    deps = [":gdn"],
-)
-
 py_test(
     name = "parameterizers_test",
     srcs = ["parameterizers_test.py"],

tensorflow_compression/python/layers/gdn.py

Lines changed: 15 additions & 5 deletions
@@ -85,7 +85,7 @@ def __init__(self,
         Defaults to `NonnegativeParameterizer` with a minimum value of 0.
       **kwargs: Other keyword arguments passed to superclass (`Layer`).
     """
-    super(GDN, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     self._inverse = bool(inverse)
     self._rectify = bool(rectify)
     self._gamma_init = float(gamma_init)
@@ -136,6 +136,14 @@ def gamma_parameterizer(self, val):
           "Can't set `gamma_parameterizer` once layer has been built.")
     self._gamma_parameterizer = val
 
+  @property
+  def beta(self):
+    return self._beta.value()
+
+  @property
+  def gamma(self):
+    return self._gamma.value()
+
   def _channel_axis(self):
     return {"channels_first": 1, "channels_last": -1}[self.data_format]
 
@@ -152,17 +160,17 @@ def build(self, input_shape):
 
     # Sorry, lint, but these objects really are callable ...
     # pylint:disable=not-callable
-    self.beta = self.beta_parameterizer(
+    self._beta = self.beta_parameterizer(
        name="beta", shape=[num_channels], dtype=self.dtype,
        getter=self.add_weight, initializer=tf.initializers.ones())
 
-    self.gamma = self.gamma_parameterizer(
+    self._gamma = self.gamma_parameterizer(
        name="gamma", shape=[num_channels, num_channels], dtype=self.dtype,
        getter=self.add_weight,
        initializer=tf.initializers.identity(gain=self._gamma_init))
     # pylint:enable=not-callable
 
-    self.built = True
+    super().build(input_shape)
 
   def call(self, inputs):
     inputs = tf.convert_to_tensor(inputs, dtype=self.dtype)
@@ -175,7 +183,9 @@ def call(self, inputs):
     if ndim == 2:
       norm_pool = tf.linalg.matmul(tf.math.square(inputs), self.gamma)
       norm_pool = tf.nn.bias_add(norm_pool, self.beta)
-    elif self.data_format == "channels_last" and ndim <= 5:
+    elif self.data_format == "channels_last" and ndim <= 4:
+      # TODO(unassigned): This branch should also work for ndim == 5, but
+      # currently triggers a bug in TF.
       shape = self.gamma.shape.as_list()
       gamma = tf.reshape(self.gamma, (ndim - 2) * [1] + shape)
       norm_pool = tf.nn.convolution(tf.math.square(inputs), gamma, "VALID")
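
With the properties above, reading `layer.beta` or `layer.gamma` recomputes the parameter under the active tape. A small sketch of the repaired behavior, mirroring the new gradient test below (assumes `tensorflow_compression` is importable):

import tensorflow as tf
from tensorflow_compression.python.layers import gdn

layer = gdn.GDN(inverse=False, rectify=False)
x = tf.random.uniform((1, 8), dtype=tf.float32)
_ = layer(x)  # first call triggers build() outside any tape
with tf.GradientTape() as tape:
  y = layer(x)  # beta and gamma are recomputed here, under the tape
grads = tape.gradient(y, layer.trainable_variables)
assert None not in grads  # both variables receive gradients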

tensorflow_compression/python/layers/gdn_test.py

Lines changed: 40 additions & 45 deletions
@@ -1,3 +1,4 @@
+# Lint as: python3
 # Copyright 2018 Google LLC. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,70 +15,64 @@
 # ==============================================================================
 """Tests of GDN layer."""
 
-import numpy as np
-import tensorflow.compat.v1 as tf
+import tensorflow.compat.v2 as tf
 
-from tensorflow.python.framework import test_util
 from tensorflow_compression.python.layers import gdn
 
 
-@test_util.deprecated_graph_mode_only
 class GDNTest(tf.test.TestCase):
 
-  def _run_gdn(self, x, shape, inverse, rectify, data_format):
-    inputs = tf.placeholder(tf.float32, shape)
-    layer = gdn.GDN(
-        inverse=inverse, rectify=rectify, data_format=data_format)
-    outputs = layer(inputs)
-    with self.cached_session() as sess:
-      tf.global_variables_initializer().run()
-      y, = sess.run([outputs], {inputs: x})
-    return y
-
-  def test_invalid_data_format(self):
-    x = np.random.uniform(size=(1, 2, 3, 4))
+  def test_invalid_data_format_raises_error(self):
+    x = tf.random.uniform((1, 2, 3, 4), dtype=tf.float32)
     with self.assertRaises(ValueError):
-      self._run_gdn(x, x.shape, False, False, "NHWC")
+      gdn.GDN(inverse=False, rectify=False, data_format="NHWC")(x)
 
-  def test_unknown_dim(self):
-    x = np.random.uniform(size=(1, 2, 3, 4))
+  def test_vector_input_raises_error(self):
+    x = tf.random.uniform((3,), dtype=tf.float32)
+    with self.assertRaises(ValueError):
+      gdn.GDN(inverse=False, rectify=False, data_format="channels_last")(x)
     with self.assertRaises(ValueError):
-      self._run_gdn(x, 4 * [None], False, False, "channels_last")
+      gdn.GDN(inverse=True, rectify=True, data_format="channels_first")(x)
 
-  def test_channels_last(self):
+  def test_channels_last_has_correct_output(self):
+    # This tests that the layer produces the correct output for a number of
+    # different input dimensionalities with 'channels_last' data format.
     for ndim in [2, 3, 4, 5, 6]:
-      x = np.random.uniform(size=(1, 2, 3, 4, 5, 6)[:ndim])
-      y = self._run_gdn(x, x.shape, False, False, "channels_last")
+      x = tf.random.uniform((1, 2, 3, 4, 5, 6)[:ndim], dtype=tf.float32)
+      y = gdn.GDN(inverse=False, rectify=False, data_format="channels_last")(x)
       self.assertEqual(x.shape, y.shape)
-      self.assertAllClose(y, x / np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6)
+      self.assertAllClose(y, x / tf.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6)
 
-  def test_channels_first(self):
+  def test_channels_first_has_correct_output(self):
+    # This tests that the layer produces the correct output for a number of
+    # different input dimensionalities with 'channels_first' data format.
     for ndim in [2, 3, 4, 5, 6]:
-      x = np.random.uniform(size=(6, 5, 4, 3, 2, 1)[:ndim])
-      y = self._run_gdn(x, x.shape, False, False, "channels_first")
+      x = tf.random.uniform((6, 5, 4, 3, 2, 1)[:ndim], dtype=tf.float32)
+      y = gdn.GDN(inverse=False, rectify=False, data_format="channels_first")(x)
       self.assertEqual(x.shape, y.shape)
-      self.assertAllClose(
-          y, x / np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6)
-
-  def test_wrong_dims(self):
-    x = np.random.uniform(size=(3,))
-    with self.assertRaises(ValueError):
-      self._run_gdn(x, x.shape, False, False, "channels_last")
-    with self.assertRaises(ValueError):
-      self._run_gdn(x, x.shape, True, True, "channels_first")
+      self.assertAllClose(y, x / tf.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6)
 
-  def test_igdn(self):
-    x = np.random.uniform(size=(1, 2, 3, 4))
-    y = self._run_gdn(x, x.shape, True, False, "channels_last")
+  def test_igdn_has_correct_output(self):
+    x = tf.random.uniform((1, 2, 3, 4), dtype=tf.float32)
+    y = gdn.GDN(inverse=True, rectify=False)(x)
     self.assertEqual(x.shape, y.shape)
-    self.assertAllClose(y, x * np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6)
+    self.assertAllClose(y, x * tf.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6)
 
-  def test_rgdn(self):
-    x = np.random.uniform(-.5, .5, size=(1, 2, 3, 4))
-    y = self._run_gdn(x, x.shape, False, True, "channels_last")
+  def test_rgdn_has_correct_output(self):
+    x = tf.random.uniform((1, 2, 3, 4), -.5, .5, dtype=tf.float32)
+    y = gdn.GDN(inverse=False, rectify=True)(x)
     self.assertEqual(x.shape, y.shape)
-    x = np.maximum(x, 0)
-    self.assertAllClose(y, x / np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6)
+    x = tf.maximum(x, 0)
+    self.assertAllClose(y, x / tf.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6)
+
+  def test_variables_receive_gradients(self):
+    x = tf.random.uniform((1, 2), dtype=tf.float32)
+    layer = gdn.GDN(inverse=False, rectify=True)
+    with tf.GradientTape() as g:
+      y = layer(x)
+    grads = g.gradient(y, layer.trainable_variables)
+    self.assertLen(grads, 2)
+    self.assertNotIn(None, grads)
 
 
 if __name__ == "__main__":

tensorflow_compression/python/layers/parameterizers.py

Lines changed: 19 additions & 4 deletions
@@ -21,13 +21,28 @@
 
 
 __all__ = [
+    "Parameter",
     "Parameterizer",
     "StaticParameterizer",
     "RDFTParameterizer",
     "NonnegativeParameterizer",
 ]
 
 
+class Parameter(object):
+  """Reparameterized `Layer` variable.
+
+  This object represents a parameter of a `tf.keras.layers.Layer` object which
+  isn't directly stored in a `tf.Variable`. Instead, the value is computed on
+  demand by calling its `value()` method.
+  """
+
+  def __init__(self, value):
+    if not callable(value):
+      raise TypeError("`value` must be callable without arguments.")
+    self.value = value
+
+
 class Parameterizer(object):
   """Parameterization object (abstract base class).
 
@@ -69,9 +84,9 @@ def __call__(self, getter, name, shape, dtype, initializer, regularizer=None):
           "static parameterizers.")
     if callable(self.value):
       # Treat value as initializer.
-      return self.value(shape, dtype=dtype)
+      return Parameter(lambda: self.value(shape, dtype=dtype))
     else:
-      return self.value
+      return Parameter(lambda: self.value)
 
 
 class RDFTParameterizer(Parameterizer):
@@ -137,7 +152,7 @@ def reparam(rdft):
     rdft = getter(
        name=rdft_name, shape=rdft_shape, dtype=rdft_dtype,
        initializer=rdft_initializer, regularizer=reparam_regularizer)
-    return reparam(rdft)
+    return Parameter(lambda: reparam(rdft))
 
 
 class NonnegativeParameterizer(Parameterizer):
@@ -194,4 +209,4 @@ def reparam(var):
     var = getter(
        name=reparam_name, shape=shape, dtype=dtype,
        initializer=reparam_initializer, regularizer=reparam_regularizer)
-    return reparam(var)
+    return Parameter(lambda: reparam(var))
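
The `Parameter` wrapper just defers a computation until `value()` is called. A standalone sketch of why that restores gradient flow (the `softplus` reparameterization here is a hypothetical stand-in for the library's `reparam` functions):

import tensorflow as tf


class Parameter(object):  # as added in the diff above
  def __init__(self, value):
    if not callable(value):
      raise TypeError("`value` must be callable without arguments.")
    self.value = value


var = tf.Variable(tf.zeros((3,)))
# The lambda captures `var`; nothing is computed until value() is called,
# so the reparameterization runs under whatever tape is active at that point.
param = Parameter(lambda: tf.math.softplus(var))

with tf.GradientTape() as tape:
  y = tf.reduce_sum(param.value())  # reparameterization happens under the tape
print(tape.gradient(y, var))  # non-None: gradient flows through softplus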

tensorflow_compression/python/layers/parameterizers_test.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ class ParameterizersTest(tf.test.TestCase):
   def _test_parameterizer(self, param, init, shape):
     var = param(
        getter=tf.get_variable, name="test", shape=shape, dtype=tf.float32,
-        initializer=init, regularizer=None)
+        initializer=init, regularizer=None).value()
     with self.cached_session() as sess:
       tf.global_variables_initializer().run()
       var, = sess.run([var])

tensorflow_compression/python/layers/signal_conv.py

Lines changed: 3 additions & 5 deletions
@@ -350,11 +350,11 @@ def bias_parameterizer(self):
 
   @property
   def kernel(self):
-    return self._kernel
+    return self._kernel.value()
 
   @property
   def bias(self):
-    return self._bias
+    return self._bias.value()
 
   @property
   def _op_data_format(self):
@@ -432,8 +432,6 @@ def build(self, input_shape):
       self._bias = getter(
          name="bias", shape=(output_channels,), dtype=self.dtype,
          initializer=self.bias_initializer, regularizer=self.bias_regularizer)
-    else:
-      self._bias = None
 
     super(_SignalConv, self).build(input_shape)
@@ -778,7 +776,7 @@ def call(self, inputs):
       self._raise_notimplemented()
 
     # Now, add bias if requested.
-    if self.bias is not None:
+    if self.use_bias:
       bias = self.bias
       if self.data_format == "channels_first":
         # As of Mar 2017, direct addition is significantly slower than
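
The last hunk's guard changes because `build()` no longer sets `self._bias = None` when bias is disabled; the attribute may simply not exist, so the configuration flag `self.use_bias` is the reliable check. A minimal sketch of the pattern (hypothetical `TinyDense` layer, not the library code):

import tensorflow as tf


class TinyDense(tf.keras.layers.Layer):
  """Hypothetical layer illustrating the configuration-flag bias guard."""

  def __init__(self, units, use_bias=True, **kwargs):
    super().__init__(**kwargs)
    self.units = units
    self.use_bias = use_bias

  def build(self, input_shape):
    self._kernel = self.add_weight(
        name="kernel", shape=(int(input_shape[-1]), self.units))
    if self.use_bias:  # `self._bias` simply doesn't exist otherwise
      self._bias = self.add_weight(name="bias", shape=(self.units,))
    super().build(input_shape)

  def call(self, inputs):
    outputs = tf.matmul(inputs, self._kernel)
    # Checking an attribute that may be unset would raise AttributeError;
    # the flag reflects the layer's configuration directly.
    if self.use_bias:
      outputs = tf.nn.bias_add(outputs, self._bias)
    return outputs


print(TinyDense(4, use_bias=False)(tf.ones((1, 3))).shape)  # (1, 4)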
