24
24
]
25
25
26
26
27
# Type of the alpha/beta/gamma/epsilon parameter properties: unset (None),
# a fixed tensor value, or a zero-argument callable producing the tensor.
ParameterType = Union[None, tf.Tensor, Callable[[], tf.Tensor]]
28
+
29
+
27
30
@tf .keras .utils .register_keras_serializable (package = "tensorflow_compression" )
28
31
class GDN (tf .keras .layers .Layer ):
29
32
"""Generalized divisive normalization layer.
@@ -39,16 +42,17 @@ class GDN(tf.keras.layers.Layer):
39
42
> J. Ballé, V. Laparra, E.P. Simoncelli<br />
40
43
> https://arxiv.org/abs/1611.01704
41
44
42
- Implements an activation function that is essentially a multivariate
43
- generalization of a particular sigmoid-type function:
45
+ Implements an activation function that is a multivariate generalization of a
46
+ particular sigmoid-type function:
44
47
45
48
```
46
- y[i] = x[i] / sqrt (beta[i] + sum_j(gamma[j, i] * x[j]^2))
49
+ y[i] = x[i] / (beta[i] + sum_j(gamma[j, i] * | x[j]|^alpha))^epsilon
47
50
```
48
51
49
52
where `i` and `j` run over channels. This implementation never sums across
50
53
spatial dimensions. It is similar to local response normalization, but much
51
- more flexible, as `beta` and `gamma` are trainable parameters.
54
+ more flexible, as `alpha`, `beta`, `gamma`, and `epsilon` are trainable
55
+ parameters.
52
56
53
57
Attributes:
54
58
inverse: Boolean. If `False`, compute GDN response. If `True`, compute IGDN
@@ -58,6 +62,14 @@ class GDN(tf.keras.layers.Layer):
58
62
before calculating GDN response.
59
63
data_format: String. Format of input tensor. Currently supports
60
64
`'channels_first'` and `'channels_last'`.
65
+ alpha_parameter: Scalar, callable, or `None`. A number or scalar `tf.Tensor`
66
+ means that the value of alpha is fixed. A callable can be used to
67
+ determine the value of alpha as a function of some other variable or
68
+ tensor. This can be a `Parameter` object. `None` means that when the layer
69
+ is built, a `GDNParameter` object is created to train alpha (with a
70
+ minimum value of 1). The default is a fixed value of 1. Note that certain
71
+ choices here such as `tf.Tensor`s or lambda functions may prevent
72
+ JSON-style serialization (`Parameter` objects and Python constants work).
61
73
beta_parameter: Tensor, callable, or `None`. A `tf.Tensor` means that the
62
74
value of beta is fixed. A callable can be used to determine the value of
63
75
beta as a function of some other variable or tensor. This can be a
@@ -72,53 +84,86 @@ class GDN(tf.keras.layers.Layer):
72
84
`GDNParameter` object is created to train gamma. Note that certain choices
73
85
here such as `tf.Tensor`s or lambda functions may prevent JSON-style
74
86
serialization (`Parameter` objects work).
87
+ epsilon_parameter: Scalar, callable, or `None`. A number or scalar
88
+ `tf.Tensor` means that the value of epsilon is fixed. A callable can be
89
+ used to determine the value of epsilon as a function of some other
90
+ variable or tensor. This can be a `Parameter` object. `None` means that
91
+ when the layer is built, a `GDNParameter` object is created to train
92
+ epsilon (with a minimum value of 1e-6). The default is a fixed value of 1.
93
+ Note that certain choices here such as `tf.Tensor`s or lambda functions
94
+ may prevent JSON-style serialization (`Parameter` objects and Python
95
+ constants work).
96
+ alpha_initializer: `Initializer` object for alpha parameter. Only used if
97
+ alpha is trained. Defaults to 1.
75
98
beta_initializer: `Initializer` object for beta parameter. Only used if beta
76
99
is created on building the layer. Defaults to 1.
77
100
gamma_initializer: `Initializer` object for gamma parameter. Only used if
78
101
gamma is created on building the layer. Defaults to identity matrix
79
102
multiplied by 0.1. A good default value for the diagonal is somewhere
80
103
between 0 and 0.5. If set to 0 and beta initialized as 1, the layer is
81
104
effectively initialized to the identity operation.
105
+ epsilon_initializer: `Initializer` object for epsilon parameter. Only used
106
+ if epsilon is trained. Defaults to 1.
107
+ alpha: `tf.Tensor`. Read-only property always returning the current value of
108
+ alpha.
82
109
beta: `tf.Tensor`. Read-only property always returning the current value of
83
110
beta.
84
111
gamma: `tf.Tensor`. Read-only property always returning the current value of
85
112
gamma.
113
+ epsilon: `tf.Tensor`. Read-only property always returning the current value
114
+ of epsilon.
86
115
"""
87
116
88
117
def __init__(self,
             inverse=False,
             rectify=False,
             data_format="channels_last",
             alpha_parameter=1,
             beta_parameter=None,
             gamma_parameter=None,
             epsilon_parameter=1,
             alpha_initializer="ones",
             beta_initializer="ones",
             gamma_initializer=tf.keras.initializers.Identity(.1),
             epsilon_initializer="ones",
             **kwargs):
  """Initializer.

  Args:
    inverse: Boolean. Initial value of eponymous attribute.
    rectify: Boolean. Initial value of eponymous attribute.
    data_format: String. Initial value of eponymous attribute.
    alpha_parameter: Scalar, callable, or `None`. Initial value of eponymous
      attribute.
    beta_parameter: Tensor, callable, or `None`. Initial value of eponymous
      attribute.
    gamma_parameter: Tensor, callable, or `None`. Initial value of eponymous
      attribute.
    epsilon_parameter: Scalar, callable, or `None`. Initial value of
      eponymous attribute.
    alpha_initializer: `Initializer` object. Initial value of eponymous
      attribute.
    beta_initializer: `Initializer` object. Initial value of eponymous
      attribute.
    gamma_initializer: `Initializer` object. Initial value of eponymous
      attribute.
    epsilon_initializer: `Initializer` object. Initial value of eponymous
      attribute.
    **kwargs: Other keyword arguments passed to superclass (`Layer`).
  """
  super().__init__(**kwargs)
  self.input_spec = tf.keras.layers.InputSpec(min_ndim=2)
  # All assignments below go through the corresponding property setters,
  # which validate and (for parameters) may convert values to tensors.
  self.inverse = inverse
  self.rectify = rectify
  self.data_format = data_format
  self.alpha_parameter = alpha_parameter
  self.beta_parameter = beta_parameter
  self.gamma_parameter = gamma_parameter
  self.epsilon_parameter = epsilon_parameter
  self.alpha_initializer = alpha_initializer
  self.beta_initializer = beta_initializer
  self.gamma_initializer = gamma_initializer
  self.epsilon_initializer = epsilon_initializer
122
167
123
168
def _check_not_built (self ):
124
169
if self .built :
@@ -156,7 +201,21 @@ def data_format(self, value):
156
201
self ._data_format = value
157
202
158
203
@property
def alpha_parameter(self) -> ParameterType:
  """Fixed value, callable, or `None` determining alpha."""
  return self._alpha_parameter

@alpha_parameter.setter
def alpha_parameter(self, value):
  self._check_not_built()
  # Keras deserialization routes serialized objects through __init__ as
  # dicts, so reconstruct the object here if necessary.
  if isinstance(value, dict):
    value = tf.keras.utils.deserialize_keras_object(value)
  # Fixed (non-callable) values are stored as tensors of the layer dtype.
  if value is not None and not callable(value):
    value = tf.convert_to_tensor(value, dtype=self.dtype)
  self._alpha_parameter = value
+
217
+ @property
218
+ def beta_parameter (self ) -> ParameterType :
160
219
return self ._beta_parameter
161
220
162
221
@beta_parameter .setter
@@ -170,7 +229,7 @@ def beta_parameter(self, value):
170
229
self ._beta_parameter = value
171
230
172
231
@property
173
- def gamma_parameter (self ) -> Union [ None , tf . Tensor , Callable [[], tf . Tensor ]] :
232
+ def gamma_parameter (self ) -> ParameterType :
174
233
return self ._gamma_parameter
175
234
176
235
@gamma_parameter .setter
@@ -183,6 +242,29 @@ def gamma_parameter(self, value):
183
242
value = tf .convert_to_tensor (value , dtype = self .dtype )
184
243
self ._gamma_parameter = value
185
244
245
@property
def epsilon_parameter(self) -> ParameterType:
  """Fixed value, callable, or `None` determining epsilon."""
  return self._epsilon_parameter

@epsilon_parameter.setter
def epsilon_parameter(self, value):
  self._check_not_built()
  # Keras deserialization routes serialized objects through __init__ as
  # dicts, so reconstruct the object here if necessary.
  if isinstance(value, dict):
    value = tf.keras.utils.deserialize_keras_object(value)
  # Fixed (non-callable) values are stored as tensors of the layer dtype.
  if value is not None and not callable(value):
    value = tf.convert_to_tensor(value, dtype=self.dtype)
  self._epsilon_parameter = value
258
+
259
+ @property
260
+ def alpha_initializer (self ) -> Callable [..., tf .Tensor ]:
261
+ return self ._alpha_initializer
262
+
263
+ @alpha_initializer .setter
264
+ def alpha_initializer (self , value ):
265
+ self ._check_not_built ()
266
+ self ._alpha_initializer = tf .keras .initializers .get (value )
267
+
186
268
@property
187
269
def beta_initializer (self ) -> Callable [..., tf .Tensor ]:
188
270
return self ._beta_initializer
@@ -201,6 +283,23 @@ def gamma_initializer(self, value):
201
283
self ._check_not_built ()
202
284
self ._gamma_initializer = tf .keras .initializers .get (value )
203
285
286
@property
def epsilon_initializer(self) -> Callable[..., tf.Tensor]:
  """Initializer used for epsilon when it is trained."""
  return self._epsilon_initializer

@epsilon_initializer.setter
def epsilon_initializer(self, value):
  self._check_not_built()
  # Accepts strings, configs, or Initializer objects.
  self._epsilon_initializer = tf.keras.initializers.get(value)
294
+
295
@property
def alpha(self) -> tf.Tensor:
  """Current value of alpha as a tensor."""
  parameter = self.alpha_parameter
  if parameter is None:
    raise RuntimeError("alpha is not initialized yet. Call build().")
  if callable(parameter):
    return parameter()
  return tf.convert_to_tensor(parameter, dtype=self.dtype)
302
+
204
303
@property
205
304
def beta (self ) -> tf .Tensor :
206
305
if self .beta_parameter is None :
@@ -217,6 +316,14 @@ def gamma(self) -> tf.Tensor:
217
316
return tf .convert_to_tensor (self .gamma_parameter (), dtype = self .dtype )
218
317
return self .gamma_parameter
219
318
319
@property
def epsilon(self) -> tf.Tensor:
  """Current value of epsilon as a tensor."""
  parameter = self.epsilon_parameter
  if parameter is None:
    raise RuntimeError("epsilon is not initialized yet. Call build().")
  if callable(parameter):
    return parameter()
  return tf.convert_to_tensor(parameter, dtype=self.dtype)
326
+
220
327
@property
def _channel_axis(self):
  """Axis index of the channel dimension for the current data format."""
  # Intentionally raises KeyError for an unsupported data_format.
  axis_for_format = {"channels_first": 1, "channels_last": -1}
  return axis_for_format[self.data_format]
@@ -231,6 +338,12 @@ def build(self, input_shape):
231
338
self .input_spec = tf .keras .layers .InputSpec (
232
339
min_ndim = 2 , axes = {channel_axis : num_channels })
233
340
341
+ if self .alpha_parameter is None :
342
+ initial_value = self .alpha_initializer (
343
+ shape = [], dtype = self .dtype )
344
+ self .alpha_parameter = parameters .GDNParameter (
345
+ initial_value , name = "alpha" , minimum = 1 )
346
+
234
347
if self .beta_parameter is None :
235
348
initial_value = self .beta_initializer (
236
349
shape = [num_channels ], dtype = self .dtype )
@@ -243,6 +356,12 @@ def build(self, input_shape):
243
356
self .gamma_parameter = parameters .GDNParameter (
244
357
initial_value , name = "gamma" , minimum = 0 )
245
358
359
+ if self .epsilon_parameter is None :
360
+ initial_value = self .epsilon_initializer (
361
+ shape = [], dtype = self .dtype )
362
+ self .epsilon_parameter = parameters .GDNParameter (
363
+ initial_value , name = "epsilon" , minimum = 1e-6 )
364
+
246
365
super ().build (input_shape )
247
366
248
367
def call (self , inputs ) -> tf .Tensor :
@@ -254,49 +373,88 @@ def call(self, inputs) -> tf.Tensor:
254
373
if self .rectify :
255
374
inputs = tf .nn .relu (inputs )
256
375
376
+ # Optimize for fixed alphas.
377
+ if not callable (self .alpha_parameter ) and self .alpha == 1 and self .rectify :
378
+ norm_pool = inputs
379
+ elif not callable (self .alpha_parameter ) and self .alpha == 1 :
380
+ norm_pool = abs (inputs )
381
+ elif not callable (self .alpha_parameter ) and self .alpha == 2 :
382
+ norm_pool = tf .math .square (inputs )
383
+ else :
384
+ norm_pool = inputs ** self .alpha
385
+
257
386
# Compute normalization pool.
258
387
if rank == 2 :
259
- norm_pool = tf .linalg .matmul (tf . math . square ( inputs ) , self .gamma )
388
+ norm_pool = tf .linalg .matmul (norm_pool , self .gamma )
260
389
norm_pool = tf .nn .bias_add (norm_pool , self .beta )
261
390
elif self .data_format == "channels_last" and rank <= 5 :
262
391
shape = self .gamma .shape
263
392
gamma = tf .reshape (self .gamma , (rank - 2 ) * [1 ] + shape )
264
- norm_pool = tf .nn .convolution (
265
- tf .math .square (inputs ), gamma , padding = "VALID" )
393
+ norm_pool = tf .nn .convolution (norm_pool , gamma , padding = "VALID" )
266
394
norm_pool = tf .nn .bias_add (norm_pool , self .beta )
267
395
else : # generic implementation
268
396
# This puts channels in the last dimension regardless of input.
269
397
norm_pool = tf .linalg .tensordot (
270
- tf . math . square ( inputs ) , self .gamma , [[self ._channel_axis ], [0 ]])
398
+ norm_pool , self .gamma , [[self ._channel_axis ], [0 ]])
271
399
norm_pool += self .beta
272
400
if self .data_format == "channels_first" :
273
401
# Return to channels_first format if necessary.
274
402
axes = list (range (rank - 1 ))
275
403
axes .insert (1 , rank - 1 )
276
404
norm_pool = tf .transpose (norm_pool , axes )
277
405
278
- if self .inverse :
406
+ # Optimize for fixed epsilons.
407
+ if not callable (self .epsilon_parameter ) and self .epsilon == 1 :
408
+ pass
409
+ elif not callable (self .epsilon_parameter ) and self .epsilon == .5 :
279
410
norm_pool = tf .math .sqrt (norm_pool )
280
411
else :
281
- norm_pool = tf .math .rsqrt (norm_pool )
282
- return inputs * norm_pool
412
+ norm_pool = norm_pool ** self .epsilon
413
+
414
+ if self .inverse :
415
+ return inputs * norm_pool
416
+ else :
417
+ return inputs / norm_pool
283
418
284
419
def compute_output_shape(self, input_shape) -> tf.TensorShape:
  """Returns the output shape, which is identical to the input shape."""
  return tf.TensorShape(input_shape)
286
421
287
422
def get_config(self) -> Dict[str, Any]:
  """Returns the configuration of the layer for serialization."""
  config = super().get_config()

  # Since alpha and epsilon are scalar, allow fixed values to be serialized
  # as plain floats when Keras object serialization is not possible.
  def try_serialize(parameter, name):
    try:
      return tf.keras.utils.serialize_keras_object(parameter)
    except (ValueError, TypeError):  # Should throw TypeError, but doesn't...
      try:
        return float(parameter)
      except TypeError:
        raise TypeError(
            f"Can't serialize {name} of type '{type(parameter)}'.")

  alpha_parameter = try_serialize(self.alpha_parameter, "alpha_parameter")
  epsilon_parameter = try_serialize(
      self.epsilon_parameter, "epsilon_parameter")

  config.update(
      inverse=self.inverse,
      rectify=self.rectify,
      data_format=self.data_format,
      alpha_parameter=alpha_parameter,
      beta_parameter=tf.keras.utils.serialize_keras_object(
          self.beta_parameter),
      gamma_parameter=tf.keras.utils.serialize_keras_object(
          self.gamma_parameter),
      epsilon_parameter=epsilon_parameter,
      alpha_initializer=tf.keras.initializers.serialize(
          self.alpha_initializer),
      beta_initializer=tf.keras.initializers.serialize(
          self.beta_initializer),
      gamma_initializer=tf.keras.initializers.serialize(
          self.gamma_initializer),
      epsilon_initializer=tf.keras.initializers.serialize(
          self.epsilon_initializer),
  )
  return config
0 commit comments