Skip to content

Commit 57d41b7

Browse files
Johannes Ballé and copybara-github
authored and committed
Adds the ability to train the exponents of GDN.
PiperOrigin-RevId: 355011064 Change-Id: Ib1cf5836cbf71be7e12765243c768e2d987df3fa
1 parent 89a4efe commit 57d41b7

File tree

2 files changed

+226
-36
lines changed

2 files changed

+226
-36
lines changed

tensorflow_compression/python/layers/gdn.py

Lines changed: 171 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
]
2525

2626

27+
ParameterType = Union[None, tf.Tensor, Callable[[], tf.Tensor]]
28+
29+
2730
@tf.keras.utils.register_keras_serializable(package="tensorflow_compression")
2831
class GDN(tf.keras.layers.Layer):
2932
"""Generalized divisive normalization layer.
@@ -39,16 +42,17 @@ class GDN(tf.keras.layers.Layer):
3942
> J. Ballé, V. Laparra, E.P. Simoncelli<br />
4043
> https://arxiv.org/abs/1611.01704
4144
42-
Implements an activation function that is essentially a multivariate
43-
generalization of a particular sigmoid-type function:
45+
Implements an activation function that is a multivariate generalization of a
46+
particular sigmoid-type function:
4447
4548
```
46-
y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j]^2))
49+
y[i] = x[i] / (beta[i] + sum_j(gamma[j, i] * |x[j]|^alpha))^epsilon
4750
```
4851
4952
where `i` and `j` run over channels. This implementation never sums across
5053
spatial dimensions. It is similar to local response normalization, but much
51-
more flexible, as `beta` and `gamma` are trainable parameters.
54+
more flexible, as `alpha`, `beta`, `gamma`, and `epsilon` are trainable
55+
parameters.
5256
5357
Attributes:
5458
inverse: Boolean. If `False`, compute GDN response. If `True`, compute IGDN
@@ -58,6 +62,14 @@ class GDN(tf.keras.layers.Layer):
5862
before calculating GDN response.
5963
data_format: String. Format of input tensor. Currently supports
6064
`'channels_first'` and `'channels_last'`.
65+
alpha_parameter: Scalar, callable, or `None`. A number or scalar `tf.Tensor`
66+
means that the value of alpha is fixed. A callable can be used to
67+
determine the value of alpha as a function of some other variable or
68+
tensor. This can be a `Parameter` object. `None` means that when the layer
69+
is built, a `GDNParameter` object is created to train alpha (with a
70+
minimum value of 1). The default is a fixed value of 1. Note that certain
71+
choices here such as `tf.Tensor`s or lambda functions may prevent
72+
JSON-style serialization (`Parameter` objects and Python constants work).
6173
beta_parameter: Tensor, callable, or `None`. A `tf.Tensor` means that the
6274
value of beta is fixed. A callable can be used to determine the value of
6375
beta as a function of some other variable or tensor. This can be a
@@ -72,53 +84,86 @@ class GDN(tf.keras.layers.Layer):
7284
`GDNParameter` object is created to train gamma. Note that certain choices
7385
here such as `tf.Tensor`s or lambda functions may prevent JSON-style
7486
serialization (`Parameter` objects work).
87+
epsilon_parameter: Scalar, callable, or `None`. A number or scalar
88+
`tf.Tensor` means that the value of epsilon is fixed. A callable can be
89+
used to determine the value of epsilon as a function of some other
90+
variable or tensor. This can be a `Parameter` object. `None` means that
91+
when the layer is built, a `GDNParameter` object is created to train
92+
epsilon (with a minimum value of 1e-6). The default is a fixed value of 1.
93+
Note that certain choices here such as `tf.Tensor`s or lambda functions
94+
may prevent JSON-style serialization (`Parameter` objects and Python
95+
constants work).
96+
alpha_initializer: `Initializer` object for alpha parameter. Only used if
97+
alpha is trained. Defaults to 1.
7598
beta_initializer: `Initializer` object for beta parameter. Only used if beta
7699
is created on building the layer. Defaults to 1.
77100
gamma_initializer: `Initializer` object for gamma parameter. Only used if
78101
gamma is created on building the layer. Defaults to identity matrix
79102
multiplied by 0.1. A good default value for the diagonal is somewhere
80103
between 0 and 0.5. If set to 0 and beta initialized as 1, the layer is
81104
effectively initialized to the identity operation.
105+
epsilon_initializer: `Initializer` object for epsilon parameter. Only used
106+
if epsilon is trained. Defaults to 1.
107+
alpha: `tf.Tensor`. Read-only property always returning the current value of
108+
alpha.
82109
beta: `tf.Tensor`. Read-only property always returning the current value of
83110
beta.
84111
gamma: `tf.Tensor`. Read-only property always returning the current value of
85112
gamma.
113+
epsilon: `tf.Tensor`. Read-only property always returning the current value
114+
of epsilon.
86115
"""
87116

88117
def __init__(self,
89118
inverse=False,
90119
rectify=False,
91120
data_format="channels_last",
121+
alpha_parameter=1,
92122
beta_parameter=None,
93123
gamma_parameter=None,
124+
epsilon_parameter=1,
125+
alpha_initializer="ones",
94126
beta_initializer="ones",
95127
gamma_initializer=tf.keras.initializers.Identity(.1),
128+
epsilon_initializer="ones",
96129
**kwargs):
97130
"""Initializer.
98131
99132
Args:
100133
inverse: Boolean. Initial value of eponymous attribute.
101134
rectify: Boolean. Initial value of eponymous attribute.
102135
data_format: String. Initial value of eponymous attribute.
136+
alpha_parameter: Scalar, callable, or `None`. Initial value of eponymous
137+
attribute.
103138
beta_parameter: Tensor, callable, or `None`. Initial value of eponymous
104139
attribute.
105140
gamma_parameter: Tensor, callable, or `None`. Initial value of eponymous
106141
attribute.
142+
epsilon_parameter: Scalar, callable, or `None`. Initial value of
143+
eponymous attribute.
144+
alpha_initializer: `Initializer` object. Initial value of eponymous
145+
attribute.
107146
beta_initializer: `Initializer` object. Initial value of eponymous
108147
attribute.
109148
gamma_initializer: `Initializer` object. Initial value of eponymous
110149
attribute.
150+
epsilon_initializer: `Initializer` object. Initial value of eponymous
151+
attribute.
111152
**kwargs: Other keyword arguments passed to superclass (`Layer`).
112153
"""
113154
super().__init__(**kwargs)
114155
self.input_spec = tf.keras.layers.InputSpec(min_ndim=2)
115156
self.inverse = inverse
116157
self.rectify = rectify
117158
self.data_format = data_format
159+
self.alpha_parameter = alpha_parameter
118160
self.beta_parameter = beta_parameter
119161
self.gamma_parameter = gamma_parameter
162+
self.epsilon_parameter = epsilon_parameter
163+
self.alpha_initializer = alpha_initializer
120164
self.beta_initializer = beta_initializer
121165
self.gamma_initializer = gamma_initializer
166+
self.epsilon_initializer = epsilon_initializer
122167

123168
def _check_not_built(self):
124169
if self.built:
@@ -156,7 +201,21 @@ def data_format(self, value):
156201
self._data_format = value
157202

158203
@property
159-
def beta_parameter(self) -> Union[None, tf.Tensor, Callable[[], tf.Tensor]]:
204+
def alpha_parameter(self) -> ParameterType:
205+
return self._alpha_parameter
206+
207+
@alpha_parameter.setter
208+
def alpha_parameter(self, value):
209+
self._check_not_built()
210+
# This is necessary to make Keras deserialization via __init__ work.
211+
if isinstance(value, dict):
212+
value = tf.keras.utils.deserialize_keras_object(value)
213+
if value is not None and not callable(value):
214+
value = tf.convert_to_tensor(value, dtype=self.dtype)
215+
self._alpha_parameter = value
216+
217+
@property
218+
def beta_parameter(self) -> ParameterType:
160219
return self._beta_parameter
161220

162221
@beta_parameter.setter
@@ -170,7 +229,7 @@ def beta_parameter(self, value):
170229
self._beta_parameter = value
171230

172231
@property
173-
def gamma_parameter(self) -> Union[None, tf.Tensor, Callable[[], tf.Tensor]]:
232+
def gamma_parameter(self) -> ParameterType:
174233
return self._gamma_parameter
175234

176235
@gamma_parameter.setter
@@ -183,6 +242,29 @@ def gamma_parameter(self, value):
183242
value = tf.convert_to_tensor(value, dtype=self.dtype)
184243
self._gamma_parameter = value
185244

245+
@property
246+
def epsilon_parameter(self) -> ParameterType:
247+
return self._epsilon_parameter
248+
249+
@epsilon_parameter.setter
250+
def epsilon_parameter(self, value):
251+
self._check_not_built()
252+
# This is necessary to make Keras deserialization via __init__ work.
253+
if isinstance(value, dict):
254+
value = tf.keras.utils.deserialize_keras_object(value)
255+
if value is not None and not callable(value):
256+
value = tf.convert_to_tensor(value, dtype=self.dtype)
257+
self._epsilon_parameter = value
258+
259+
@property
260+
def alpha_initializer(self) -> Callable[..., tf.Tensor]:
261+
return self._alpha_initializer
262+
263+
@alpha_initializer.setter
264+
def alpha_initializer(self, value):
265+
self._check_not_built()
266+
self._alpha_initializer = tf.keras.initializers.get(value)
267+
186268
@property
187269
def beta_initializer(self) -> Callable[..., tf.Tensor]:
188270
return self._beta_initializer
@@ -201,6 +283,23 @@ def gamma_initializer(self, value):
201283
self._check_not_built()
202284
self._gamma_initializer = tf.keras.initializers.get(value)
203285

286+
@property
287+
def epsilon_initializer(self) -> Callable[..., tf.Tensor]:
288+
return self._epsilon_initializer
289+
290+
@epsilon_initializer.setter
291+
def epsilon_initializer(self, value):
292+
self._check_not_built()
293+
self._epsilon_initializer = tf.keras.initializers.get(value)
294+
295+
@property
296+
def alpha(self) -> tf.Tensor:
297+
if self.alpha_parameter is None:
298+
raise RuntimeError("alpha is not initialized yet. Call build().")
299+
if callable(self.alpha_parameter):
300+
return self.alpha_parameter()
301+
return tf.convert_to_tensor(self.alpha_parameter, dtype=self.dtype)
302+
204303
@property
205304
def beta(self) -> tf.Tensor:
206305
if self.beta_parameter is None:
@@ -217,6 +316,14 @@ def gamma(self) -> tf.Tensor:
217316
return tf.convert_to_tensor(self.gamma_parameter(), dtype=self.dtype)
218317
return self.gamma_parameter
219318

319+
@property
320+
def epsilon(self) -> tf.Tensor:
321+
if self.epsilon_parameter is None:
322+
raise RuntimeError("epsilon is not initialized yet. Call build().")
323+
if callable(self.epsilon_parameter):
324+
return self.epsilon_parameter()
325+
return tf.convert_to_tensor(self.epsilon_parameter, dtype=self.dtype)
326+
220327
@property
221328
def _channel_axis(self):
222329
return {"channels_first": 1, "channels_last": -1}[self.data_format]
@@ -231,6 +338,12 @@ def build(self, input_shape):
231338
self.input_spec = tf.keras.layers.InputSpec(
232339
min_ndim=2, axes={channel_axis: num_channels})
233340

341+
if self.alpha_parameter is None:
342+
initial_value = self.alpha_initializer(
343+
shape=[], dtype=self.dtype)
344+
self.alpha_parameter = parameters.GDNParameter(
345+
initial_value, name="alpha", minimum=1)
346+
234347
if self.beta_parameter is None:
235348
initial_value = self.beta_initializer(
236349
shape=[num_channels], dtype=self.dtype)
@@ -243,6 +356,12 @@ def build(self, input_shape):
243356
self.gamma_parameter = parameters.GDNParameter(
244357
initial_value, name="gamma", minimum=0)
245358

359+
if self.epsilon_parameter is None:
360+
initial_value = self.epsilon_initializer(
361+
shape=[], dtype=self.dtype)
362+
self.epsilon_parameter = parameters.GDNParameter(
363+
initial_value, name="epsilon", minimum=1e-6)
364+
246365
super().build(input_shape)
247366

248367
def call(self, inputs) -> tf.Tensor:
@@ -254,49 +373,88 @@ def call(self, inputs) -> tf.Tensor:
254373
if self.rectify:
255374
inputs = tf.nn.relu(inputs)
256375

376+
# Optimize for fixed alphas.
377+
if not callable(self.alpha_parameter) and self.alpha == 1 and self.rectify:
378+
norm_pool = inputs
379+
elif not callable(self.alpha_parameter) and self.alpha == 1:
380+
norm_pool = abs(inputs)
381+
elif not callable(self.alpha_parameter) and self.alpha == 2:
382+
norm_pool = tf.math.square(inputs)
383+
else:
384+
norm_pool = inputs ** self.alpha
385+
257386
# Compute normalization pool.
258387
if rank == 2:
259-
norm_pool = tf.linalg.matmul(tf.math.square(inputs), self.gamma)
388+
norm_pool = tf.linalg.matmul(norm_pool, self.gamma)
260389
norm_pool = tf.nn.bias_add(norm_pool, self.beta)
261390
elif self.data_format == "channels_last" and rank <= 5:
262391
shape = self.gamma.shape
263392
gamma = tf.reshape(self.gamma, (rank - 2) * [1] + shape)
264-
norm_pool = tf.nn.convolution(
265-
tf.math.square(inputs), gamma, padding="VALID")
393+
norm_pool = tf.nn.convolution(norm_pool, gamma, padding="VALID")
266394
norm_pool = tf.nn.bias_add(norm_pool, self.beta)
267395
else: # generic implementation
268396
# This puts channels in the last dimension regardless of input.
269397
norm_pool = tf.linalg.tensordot(
270-
tf.math.square(inputs), self.gamma, [[self._channel_axis], [0]])
398+
norm_pool, self.gamma, [[self._channel_axis], [0]])
271399
norm_pool += self.beta
272400
if self.data_format == "channels_first":
273401
# Return to channels_first format if necessary.
274402
axes = list(range(rank - 1))
275403
axes.insert(1, rank - 1)
276404
norm_pool = tf.transpose(norm_pool, axes)
277405

278-
if self.inverse:
406+
# Optimize for fixed epsilons.
407+
if not callable(self.epsilon_parameter) and self.epsilon == 1:
408+
pass
409+
elif not callable(self.epsilon_parameter) and self.epsilon == .5:
279410
norm_pool = tf.math.sqrt(norm_pool)
280411
else:
281-
norm_pool = tf.math.rsqrt(norm_pool)
282-
return inputs * norm_pool
412+
norm_pool = norm_pool ** self.epsilon
413+
414+
if self.inverse:
415+
return inputs * norm_pool
416+
else:
417+
return inputs / norm_pool
283418

284419
def compute_output_shape(self, input_shape) -> tf.TensorShape:
285420
return tf.TensorShape(input_shape)
286421

287422
def get_config(self) -> Dict[str, Any]:
288423
config = super().get_config()
424+
425+
# Since alpha and epsilon are scalar, allow fixed values to be serialized.
426+
def try_serialize(parameter, name):
427+
try:
428+
return tf.keras.utils.serialize_keras_object(parameter)
429+
except (ValueError, TypeError): # Should throw TypeError, but doesn't...
430+
try:
431+
return float(parameter)
432+
except TypeError:
433+
raise TypeError(
434+
f"Can't serialize {name} of type '{type(parameter)}'.")
435+
436+
alpha_parameter = try_serialize(
437+
self.alpha_parameter, "alpha_parameter")
438+
epsilon_parameter = try_serialize(
439+
self.epsilon_parameter, "epsilon_parameter")
440+
289441
config.update(
290442
inverse=self.inverse,
291443
rectify=self.rectify,
292444
data_format=self.data_format,
445+
alpha_parameter=alpha_parameter,
293446
beta_parameter=tf.keras.utils.serialize_keras_object(
294447
self.beta_parameter),
295448
gamma_parameter=tf.keras.utils.serialize_keras_object(
296449
self.gamma_parameter),
450+
epsilon_parameter=epsilon_parameter,
451+
alpha_initializer=tf.keras.initializers.serialize(
452+
self.alpha_initializer),
297453
beta_initializer=tf.keras.initializers.serialize(
298454
self.beta_initializer),
299455
gamma_initializer=tf.keras.initializers.serialize(
300456
self.gamma_initializer),
457+
epsilon_initializer=tf.keras.initializers.serialize(
458+
self.epsilon_initializer),
301459
)
302460
return config

0 commit comments

Comments
 (0)