Commit 3b15a46

Authored and committed by jballe
Refactoring of Python code.
- Minimizes deep TensorFlow imports.
- Makes layer classes inherit from tf.keras.layers rather than tf.layers.
- Cleans up some of the Keras layer interfaces.
- Implements simultaneous up- and downsampling in SignalConv*.

PiperOrigin-RevId: 241952245
1 parent d9dbba5 commit 3b15a46

7 files changed: 548 additions, 539 deletions
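The headline change is that the layer classes are now standard Keras layers. As an illustration only (this example is not part of the commit, and assumes the package's documented top-level alias `import tensorflow_compression as tfc` exports `GDN`), a GDN/IGDN pair now drops directly into a `tf.keras.Sequential` model:

import tensorflow as tf
import tensorflow_compression as tfc

# Analysis/synthesis sketch: GDN after the downsampling convolution,
# IGDN (inverse=True) before the upsampling transpose convolution.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(64, 5, strides=2, padding="same"),
    tfc.GDN(),
    tf.keras.layers.Conv2DTranspose(3, 5, strides=2, padding="same"),
    tfc.GDN(inverse=True),
])
outputs = model(tf.zeros((8, 32, 32, 3)))  # layers are built on first call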

python/layers/gdn.py

Lines changed: 104 additions & 62 deletions
@@ -21,14 +21,7 @@
 
 # Dependency imports
 
-from tensorflow.python.eager import context
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.layers import base
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
+import tensorflow as tf
 
 from tensorflow_compression.python.layers import parameterizers
 
@@ -38,7 +31,7 @@
 _default_gamma_param = parameterizers.NonnegativeParameterizer()
 
 
-class GDN(base.Layer):
+class GDN(tf.keras.layers.Layer):
   """Generalized divisive normalization layer.
 
   Based on the papers:
@@ -69,30 +62,44 @@ class GDN(base.Layer):
       the division is replaced by multiplication).
     rectify: Boolean. If `True`, apply a `relu` nonlinearity to the inputs
       before calculating GDN response.
-    gamma_init: The gamma matrix will be initialized as the identity matrix
-      multiplied with this value. If set to zero, the layer is effectively
-      initialized to the identity operation, since beta is initialized as one.
-      A good default setting is somewhere between 0 and 0.5.
+    gamma_init: Float. The gamma matrix will be initialized as the identity
+      matrix multiplied with this value. If set to zero, the layer is
+      effectively initialized to the identity operation, since beta is
+      initialized as one. A good default setting is somewhere between 0 and 0.5.
     data_format: Format of input tensor. Currently supports `'channels_first'`
       and `'channels_last'`.
-    beta_parameterizer: Reparameterization for beta parameter. Defaults to
-      `NonnegativeParameterizer` with a minimum value of `1e-6`.
-    gamma_parameterizer: Reparameterization for gamma parameter. Defaults to
-      `NonnegativeParameterizer` with a minimum value of `0`.
+    beta_parameterizer: `Parameterizer` object for beta parameter. Defaults
+      to `NonnegativeParameterizer` with a minimum value of 1e-6.
+    gamma_parameterizer: `Parameterizer` object for gamma parameter.
+      Defaults to `NonnegativeParameterizer` with a minimum value of 0.
     activity_regularizer: Regularizer function for the output.
-    trainable: Boolean, if `True`, also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    name: String, the name of the layer. Layers with the same name will
-      share weights, but to avoid mistakes we require `reuse=True` in such
-      cases.
+    trainable: Boolean. Whether the layer should be trained.
+    name: String. The name of the layer.
+    dtype: `DType` of the layer's inputs (default of `None` means use the type
+      of the first input).
 
-  Properties:
+  Read-only properties:
     inverse: Boolean, whether GDN is computed (`True`) or IGDN (`False`).
     rectify: Boolean, whether to apply `relu` before normalization or not.
-    data_format: Format of input tensor. Currently supports `'channels_first'`
-      and `'channels_last'`.
+    gamma_init: See above.
+    data_format: See above.
+    activity_regularizer: See above.
+    name: See above.
+    dtype: See above.
     beta: The beta parameter as defined above (1D `Tensor`).
     gamma: The gamma parameter as defined above (2D `Tensor`).
+    trainable_variables: List of trainable variables.
+    non_trainable_variables: List of non-trainable variables.
+    variables: List of all variables of this layer, trainable and non-trainable.
+    updates: List of update ops of this layer.
+    losses: List of losses added by this layer.
+
+  Mutable properties:
+    beta_parameterizer: See above.
+    gamma_parameterizer: See above.
+    trainable: Boolean. Whether the layer should be trained.
+    input_spec: Optional `InputSpec` object specifying the constraints on inputs
+      that can be accepted by the layer.
   """
 
   def __init__(self,
@@ -102,87 +109,122 @@ def __init__(self,
                data_format="channels_last",
                beta_parameterizer=_default_beta_param,
                gamma_parameterizer=_default_gamma_param,
-               activity_regularizer=None,
-               trainable=True,
-               name=None,
                **kwargs):
-    super(GDN, self).__init__(trainable=trainable, name=name,
-                              activity_regularizer=activity_regularizer,
-                              **kwargs)
-    self.inverse = bool(inverse)
-    self.rectify = bool(rectify)
+    super(GDN, self).__init__(**kwargs)
+    self._inverse = bool(inverse)
+    self._rectify = bool(rectify)
     self._gamma_init = float(gamma_init)
-    self.data_format = data_format
+    self._data_format = str(data_format)
     self._beta_parameterizer = beta_parameterizer
     self._gamma_parameterizer = gamma_parameterizer
-    self._channel_axis()  # trigger ValueError early
-    self.input_spec = base.InputSpec(min_ndim=2)
+
+    if self.data_format not in ("channels_first", "channels_last"):
+      raise ValueError("Unknown data format: '{}'.".format(self.data_format))
+
+    self.input_spec = tf.keras.layers.InputSpec(min_ndim=2)
+
+  @property
+  def inverse(self):
+    return self._inverse
+
+  @property
+  def rectify(self):
+    return self._rectify
+
+  @property
+  def gamma_init(self):
+    return self._gamma_init
+
+  @property
+  def data_format(self):
+    return self._data_format
+
+  @property
+  def beta_parameterizer(self):
+    return self._beta_parameterizer
+
+  @beta_parameterizer.setter
+  def beta_parameterizer(self, val):
+    if self.built:
+      raise RuntimeError(
+          "Can't set `beta_parameterizer` once layer has been built.")
+    self._beta_parameterizer = val
+
+  @property
+  def gamma_parameterizer(self):
+    return self._gamma_parameterizer
+
+  @gamma_parameterizer.setter
+  def gamma_parameterizer(self, val):
+    if self.built:
+      raise RuntimeError(
+          "Can't set `gamma_parameterizer` once layer has been built.")
+    self._gamma_parameterizer = val
 
   def _channel_axis(self):
-    try:
-      return {"channels_first": 1, "channels_last": -1}[self.data_format]
-    except KeyError:
-      raise ValueError("Unsupported `data_format` for GDN layer: {}.".format(
-          self.data_format))
+    return {"channels_first": 1, "channels_last": -1}[self.data_format]
 
   def build(self, input_shape):
     channel_axis = self._channel_axis()
-    input_shape = tensor_shape.TensorShape(input_shape)
+    input_shape = tf.TensorShape(input_shape)
     num_channels = input_shape[channel_axis].value
     if num_channels is None:
       raise ValueError("The channel dimension of the inputs to `GDN` "
                        "must be defined.")
     self._input_rank = input_shape.ndims
-    self.input_spec = base.InputSpec(ndim=input_shape.ndims,
-                                     axes={channel_axis: num_channels})
+    self.input_spec = tf.keras.layers.InputSpec(
+        ndim=input_shape.ndims, axes={channel_axis: num_channels})
 
-    self.beta = self._beta_parameterizer(
+    # Sorry, lint, but these objects really are callable ...
+    # pylint:disable=not-callable
+    self.beta = self.beta_parameterizer(
         name="beta", shape=[num_channels], dtype=self.dtype,
-        getter=self.add_variable, initializer=init_ops.Ones())
+        getter=self.add_variable, initializer=tf.initializers.ones())
 
-    self.gamma = self._gamma_parameterizer(
+    self.gamma = self.gamma_parameterizer(
         name="gamma", shape=[num_channels, num_channels], dtype=self.dtype,
         getter=self.add_variable,
-        initializer=init_ops.Identity(gain=self._gamma_init))
+        initializer=tf.initializers.identity(gain=self._gamma_init))
+    # pylint:enable=not-callable
 
     self.built = True
 
   def call(self, inputs):
-    inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
+    inputs = tf.convert_to_tensor(inputs, dtype=self.dtype)
     ndim = self._input_rank
 
     if self.rectify:
-      inputs = nn.relu(inputs)
+      inputs = tf.nn.relu(inputs)
 
     # Compute normalization pool.
     if ndim == 2:
-      norm_pool = math_ops.matmul(math_ops.square(inputs), self.gamma)
-      norm_pool = nn.bias_add(norm_pool, self.beta)
+      norm_pool = tf.linalg.matmul(tf.math.square(inputs), self.gamma)
+      norm_pool = tf.nn.bias_add(norm_pool, self.beta)
     elif self.data_format == "channels_last" and ndim <= 5:
       shape = self.gamma.shape.as_list()
-      gamma = array_ops.reshape(self.gamma, (ndim - 2) * [1] + shape)
-      norm_pool = nn.convolution(math_ops.square(inputs), gamma, "VALID")
-      norm_pool = nn.bias_add(norm_pool, self.beta)
+      gamma = tf.reshape(self.gamma, (ndim - 2) * [1] + shape)
+      norm_pool = tf.nn.convolution(tf.math.square(inputs), gamma, "VALID")
+      norm_pool = tf.nn.bias_add(norm_pool, self.beta)
     else:  # generic implementation
       # This puts channels in the last dimension regardless of input.
-      norm_pool = math_ops.tensordot(
-          math_ops.square(inputs), self.gamma, [[self._channel_axis()], [0]])
+      norm_pool = tf.linalg.tensordot(
+          tf.math.square(inputs), self.gamma, [[self._channel_axis()], [0]])
       norm_pool += self.beta
       if self.data_format == "channels_first":
        # Return to channels_first format if necessary.
        axes = list(range(ndim - 1))
        axes.insert(1, ndim - 1)
-        norm_pool = array_ops.transpose(norm_pool, axes)
+        norm_pool = tf.transpose(norm_pool, axes)
 
     if self.inverse:
-      norm_pool = math_ops.sqrt(norm_pool)
+      norm_pool = tf.math.sqrt(norm_pool)
     else:
-      norm_pool = math_ops.rsqrt(norm_pool)
+      norm_pool = tf.math.rsqrt(norm_pool)
     outputs = inputs * norm_pool
 
-    if not context.executing_eagerly():
+    if not tf.executing_eagerly():
       outputs.set_shape(self.compute_output_shape(inputs.shape))
     return outputs
 
   def compute_output_shape(self, input_shape):
-    return tensor_shape.TensorShape(input_shape)
+    return tf.TensorShape(input_shape)
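The cleaned-up interface splits the layer's attributes into read-only and mutable properties, and the parameterizer setters above refuse changes once the layer has been built. A usage sketch (not from the repository; it assumes the top-level `tfc` alias exports `GDN` and `NonnegativeParameterizer`):

import tensorflow as tf
import tensorflow_compression as tfc

layer = tfc.GDN(data_format="channels_last")
# Mutable before build():
layer.beta_parameterizer = tfc.NonnegativeParameterizer(minimum=1e-5)

# First call triggers build(); beta has shape [8], gamma [8, 8].
outputs = layer(tf.zeros((1, 16, 16, 8)))

try:
  layer.gamma_parameterizer = None
except RuntimeError:
  pass  # the setters raise once the layer has been built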

python/layers/initializers.py

Lines changed: 4 additions & 5 deletions
@@ -21,8 +21,7 @@
 
 # Dependency imports
 
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
+import tensorflow as tf
 
 
 class IdentityInitializer(object):
@@ -45,11 +44,11 @@ def __call__(self, shape, dtype=None, partition_info=None):
 
     support = tuple(shape[:-2]) + (1, 1)
     indices = [[s // 2 for s in support]]
-    updates = array_ops.constant([self.gain], dtype=dtype)
-    kernel = array_ops.scatter_nd(indices, updates, support)
+    updates = tf.constant([self.gain], dtype=dtype)
+    kernel = tf.scatter_nd(indices, updates, support)
 
     assert shape[-2] == shape[-1], shape
     if shape[-1] != 1:
-      kernel *= linalg_ops.eye(shape[-1], dtype=dtype)
+      kernel *= tf.eye(shape[-1], dtype=dtype)
 
     return kernel
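To make the two steps above concrete: `tf.scatter_nd` builds a kernel over `support` that is zero except for `gain` at the spatial center, and multiplying by `tf.eye` broadcasts that center tap onto the diagonal of the two channel dimensions, so convolution with the result is a (scaled) identity map. A sketch (not part of the diff; it assumes the constructor simply stores a `gain` attribute defaulting to 1, which the diff does not show):

import tensorflow as tf
from tensorflow_compression.python.layers import initializers

# Shape (3, 3, 2, 2): support == (3, 3, 1, 1), so the only nonzero
# entries of the result are kernel[1, 1, i, i] == gain for i in {0, 1}.
kernel = initializers.IdentityInitializer()((3, 3, 2, 2), dtype=tf.float32)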

python/layers/parameterizers.py

Lines changed: 20 additions & 21 deletions
@@ -13,39 +13,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Parameterizations for layer classes."""
+"""Parameterizers for layer classes."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 # Dependency imports
 
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
+import tensorflow as tf
 
-from tensorflow_compression.python.ops import math_ops as cmath_ops
-from tensorflow_compression.python.ops import spectral_ops as spectral_ops
+from tensorflow_compression.python.ops import math_ops
+from tensorflow_compression.python.ops import spectral_ops
 
 
 class Parameterizer(object):
-  """Parameterizer object (abstract base class).
+  """Parameterization object (abstract base class).
 
-  Parameterizer objects are immutable objects designed to facilitate
+  `Parameterizer`s are immutable objects designed to facilitate
   reparameterization of model parameters (tensor variables). They are called
   just like `tf.get_variable` with an additional argument `getter` specifying
   the actual function call to generate a variable (in many cases, `getter` would
   be `tf.get_variable`).
 
-  To achieve reparameterization, a parameterizer object wraps the provided
-  initializer, regularizer, and the returned variable in its own Tensorflow
+  To achieve reparameterization, a `Parameterizer` wraps the provided
+  initializer, regularizer, and the returned variable in its own TensorFlow
   code.
   """
   pass
 
 
 class StaticParameterizer(Parameterizer):
-  """A parameterization object that always returns a constant tensor.
+  """A parameterizer that always returns a constant tensor.
 
   No variables are created, hence the parameter never changes.
 
@@ -102,13 +101,13 @@ def rdft_initializer(shape, dtype=None, partition_info=None):
       assert dtype == rdft_dtype, dtype
      init = initializer(
           var_shape, dtype=var_dtype, partition_info=partition_info)
-      init = array_ops.reshape(init, (-1, rdft_shape[-1]))
-      init = math_ops.matmul(irdft_matrix, init, transpose_a=True)
+      init = tf.reshape(init, (-1, rdft_shape[-1]))
+      init = tf.linalg.matmul(irdft_matrix, init, transpose_a=True)
       return init
 
     def reparam(rdft):
-      var = math_ops.matmul(irdft_matrix, rdft)
-      var = array_ops.reshape(var, var_shape)
+      var = tf.linalg.matmul(irdft_matrix, rdft)
+      var = tf.reshape(var, var_shape)
       return var
 
     if regularizer is not None:
@@ -129,7 +128,7 @@ class NonnegativeParameterizer(Parameterizer):
   Args:
     minimum: Float. Lower bound for parameters (defaults to zero).
     reparam_offset: Float. Offset added to the reparameterization of beta and
-      gamma. The reparameterization of beta and gamma as their square roots lets
+      gamma. The parameterization of beta and gamma as their square roots lets
       the training slow down when their values are close to zero, which is
       desirable as small values in the denominator can lead to a situation where
       gradient noise on beta/gamma leads to extreme amounts of noise in the GDN
@@ -146,23 +145,23 @@ def __init__(self, minimum=0, reparam_offset=2 ** -18):
     self.reparam_offset = float(reparam_offset)
 
   def __call__(self, getter, name, shape, dtype, initializer, regularizer=None):
-    pedestal = array_ops.constant(self.reparam_offset ** 2, dtype=dtype)
-    bound = array_ops.constant(
+    pedestal = tf.constant(self.reparam_offset ** 2, dtype=dtype)
+    bound = tf.constant(
         (self.minimum + self.reparam_offset ** 2) ** .5, dtype=dtype)
     reparam_name = "reparam_" + name
 
     def reparam_initializer(shape, dtype=None, partition_info=None):
       init = initializer(shape, dtype=dtype, partition_info=partition_info)
-      init = math_ops.sqrt(init + pedestal)
+      init = tf.math.sqrt(tf.math.maximum(init + pedestal, pedestal))
       return init
 
     def reparam(var):
-      var = cmath_ops.lower_bound(var, bound)
-      var = math_ops.square(var) - pedestal
+      var = math_ops.lower_bound(var, bound)
+      var = tf.math.square(var) - pedestal
       return var
 
     if regularizer is not None:
-      regularizer = lambda rdft: regularizer(reparam(rdft))
+      regularizer = lambda var: regularizer(reparam(var))
 
     var = getter(
         name=reparam_name, shape=shape, dtype=dtype,
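The round trip above is easy to check in isolation: `reparam_initializer` stores roughly `sqrt(parameter + pedestal)` in the variable, and `reparam` squares it back after clipping at `bound`, which keeps the effective parameter at or above `minimum`. A standalone sketch of the same arithmetic (plain TensorFlow only; the repository's `math_ops.lower_bound` additionally customizes the gradient of the clipping step):

import tensorflow as tf

minimum, offset = 0.0, 2 ** -18
pedestal = offset ** 2
bound = (minimum + pedestal) ** 0.5

param = tf.constant([0.0, 1e-9, 0.25])  # desired nonnegative parameter values
var = tf.math.sqrt(param + pedestal)    # what reparam_initializer would store
back = tf.math.square(tf.maximum(var, bound)) - pedestal
# back >= minimum elementwise, and small values of `var` move the effective
# parameter slowly, which is the training stabilization the docstring describes.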
