Skip to content

Commit e0e4c09

Browse files
committed
Simplification of clustering registry.
Added test for DepthwiseConv2D. Change-Id: Ida1598ac91e2044b3a13fd8fd0cf5a6f3279133c
1 parent d668624 commit e0e4c09

File tree

4 files changed

+99
-70
lines changed

4 files changed

+99
-70
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def setUp(self):
6565
self.keras_unsupported_layer = layers.ConvLSTM2D(2, (5, 5)) # Unsupported
6666
self.custom_clusterable_layer = CustomClusterableLayer(10)
6767
self.custom_non_clusterable_layer = CustomNonClusterableLayer(10)
68+
self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1))
6869

6970
clustering_registry.ClusteringLookupRegistry.register_new_implementation(
7071
{
@@ -81,10 +82,10 @@ def setUp(self):
8182
cluster_config.CentroidInitialization.DENSITY_BASED
8283
}
8384

84-
def _build_clustered_layer_model(self, layer):
85+
def _build_clustered_layer_model(self, layer, input_shape=(10, 1)):
8586
wrapped_layer = cluster.cluster_weights(layer, **self.params)
8687
self.model.add(wrapped_layer)
87-
self.model.build(input_shape=(10, 1))
88+
self.model.build(input_shape=input_shape)
8889

8990
return wrapped_layer
9091

@@ -124,13 +125,32 @@ def testClusterKerasNonClusterableLayer(self):
124125
wrapped_layer)
125126
self.assertEqual([], wrapped_layer.layer.get_clusterable_weights())
126127

128+
@keras_parameterized.run_all_keras_modes
129+
def testDepthwiseConv2DLayerNonClusterable(self):
130+
"""
131+
Verifies that we don't cluster a DepthwiseConv2D layer,
132+
because clustering of this type of layer gives
133+
big unrecoverable accuracy loss.
134+
"""
135+
wrapped_layer = self._build_clustered_layer_model(
136+
self.keras_depthwiseconv2d_layer,
137+
input_shape=(1, 10, 10, 10)
138+
)
139+
140+
self._validate_clustered_layer(self.keras_depthwiseconv2d_layer,
141+
wrapped_layer)
142+
self.assertEqual([], wrapped_layer.layer.get_clusterable_weights())
143+
127144
def testClusterKerasUnsupportedLayer(self):
128145
"""
129146
Verifies that attempting to cluster an unsupported layer raises an
130147
exception.
131148
"""
149+
keras_unsupported_layer = self.keras_unsupported_layer
150+
# We need to build weights before check.
151+
keras_unsupported_layer.build(input_shape = (10, 10))
132152
with self.assertRaises(ValueError):
133-
cluster.cluster_weights(self.keras_unsupported_layer, **self.params)
153+
cluster.cluster_weights(keras_unsupported_layer, **self.params)
134154

135155
@keras_parameterized.run_all_keras_modes
136156
def testClusterCustomClusterableLayer(self):
@@ -149,8 +169,11 @@ def testClusterCustomNonClusterableLayer(self):
149169
Verifies that attempting to cluster a custom non-clusterable layer raises
150170
an exception.
151171
"""
172+
custom_non_clusterable_layer = self.custom_non_clusterable_layer
173+
# We need to build weights before check.
174+
custom_non_clusterable_layer.build(input_shape=(10, 10))
152175
with self.assertRaises(ValueError):
153-
cluster_wrapper.ClusterWeights(self.custom_non_clusterable_layer,
176+
cluster_wrapper.ClusterWeights(custom_non_clusterable_layer,
154177
**self.params)
155178

156179
@keras_parameterized.run_all_keras_modes
@@ -206,11 +229,14 @@ def testClusterModelUnsupportedKerasLayerRaisesError(self):
206229
Verifies that attempting to cluster a model that contains an unsupported
207230
layer raises an exception.
208231
"""
232+
keras_unsupported_layer = self.keras_unsupported_layer
233+
# We need to build weights before check.
234+
keras_unsupported_layer.build(input_shape = (10, 10))
209235
with self.assertRaises(ValueError):
210236
cluster.cluster_weights(
211237
keras.Sequential([
212238
self.keras_clusterable_layer, self.keras_non_clusterable_layer,
213-
self.custom_clusterable_layer, self.keras_unsupported_layer
239+
self.custom_clusterable_layer, keras_unsupported_layer
214240
]), **self.params)
215241

216242
def testClusterModelCustomNonClusterableLayerRaisesError(self):
@@ -219,10 +245,13 @@ def testClusterModelCustomNonClusterableLayerRaisesError(self):
219245
non-clusterable layer raises an exception.
220246
"""
221247
with self.assertRaises(ValueError):
248+
custom_non_clusterable_layer = self.custom_non_clusterable_layer
249+
# We need to build weights before check.
250+
custom_non_clusterable_layer.build(input_shape = (1, 2))
222251
cluster.cluster_weights(
223252
keras.Sequential([
224253
self.keras_clusterable_layer, self.keras_non_clusterable_layer,
225-
self.custom_clusterable_layer, self.custom_non_clusterable_layer
254+
self.custom_clusterable_layer, custom_non_clusterable_layer
226255
]), **self.params)
227256

228257
@keras_parameterized.run_all_keras_modes

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,16 @@
3333
CentroidInitialization = cluster_config.CentroidInitialization
3434

3535

36-
class NonClusterableLayer(layers.Dense):
37-
"""A custom layer that is not clusterable."""
38-
36+
class NonClusterableLayer(layers.Layer):
37+
""""A custom layer with weights that is not clusterable."""
38+
def __init__(self, units=10):
39+
super(NonClusterableLayer, self).__init__()
40+
self.add_weight(shape=(1, units),
41+
initializer='uniform',
42+
name='kernel')
43+
44+
def call(self, inputs):
45+
return tf.matmul(inputs, self.weights)
3946

4047
class AlreadyClusterableLayer(layers.Dense, clusterable_layer.ClusterableLayer):
4148
"""A custom layer that is clusterable."""

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py

Lines changed: 10 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import abc
1818
import six
1919
import tensorflow as tf
20-
2120
from tensorflow.keras import layers
2221

2322
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
@@ -230,69 +229,21 @@ class ClusteringRegistry(object):
230229
# the variables within the layers which hold the kernel weights. This
231230
# allows the wrapper to access and modify the weights.
232231
_LAYERS_WEIGHTS_MAP = {
233-
layers.ELU: [],
234-
layers.LeakyReLU: [],
235-
layers.ReLU: [],
236-
layers.Softmax: [],
237-
layers.ThresholdedReLU: [],
238232
layers.Conv1D: ['kernel'],
239233
layers.Conv2D: ['kernel'],
240234
layers.Conv2DTranspose: ['kernel'],
241235
layers.Conv3D: ['kernel'],
242236
layers.Conv3DTranspose: ['kernel'],
243-
layers.Cropping1D: [],
244-
layers.Cropping2D: [],
245-
layers.Cropping3D: [],
237+
# non-clusterable due to big unrecoverable accuracy loss
246238
layers.DepthwiseConv2D: [],
247239
layers.SeparableConv1D: ['pointwise_kernel'],
248240
layers.SeparableConv2D: ['pointwise_kernel'],
249-
layers.UpSampling1D: [],
250-
layers.UpSampling2D: [],
251-
layers.UpSampling3D: [],
252-
layers.ZeroPadding1D: [],
253-
layers.ZeroPadding2D: [],
254-
layers.ZeroPadding3D: [],
255-
layers.Activation: [],
256-
layers.ActivityRegularization: [],
257241
layers.Dense: ['kernel'],
258-
layers.Dropout: [],
259-
layers.Flatten: [],
260-
layers.Lambda: [],
261-
layers.Masking: [],
262-
layers.Permute: [],
263-
layers.RepeatVector: [],
264-
layers.Reshape: [],
265-
layers.SpatialDropout1D: [],
266-
layers.SpatialDropout2D: [],
267-
layers.SpatialDropout3D: [],
268242
layers.Embedding: ['embeddings'],
269243
layers.LocallyConnected1D: ['kernel'],
270244
layers.LocallyConnected2D: ['kernel'],
271-
layers.Add: [],
272-
layers.Average: [],
273-
layers.Concatenate: [],
274-
layers.Dot: [],
275-
layers.Maximum: [],
276-
layers.Minimum: [],
277-
layers.Multiply: [],
278-
layers.Subtract: [],
279-
layers.AlphaDropout: [],
280-
layers.GaussianDropout: [],
281-
layers.GaussianNoise: [],
282245
layers.BatchNormalization: [],
283246
layers.LayerNormalization: [],
284-
layers.AveragePooling1D: [],
285-
layers.AveragePooling2D: [],
286-
layers.AveragePooling3D: [],
287-
layers.GlobalAveragePooling1D: [],
288-
layers.GlobalAveragePooling2D: [],
289-
layers.GlobalAveragePooling3D: [],
290-
layers.GlobalMaxPooling1D: [],
291-
layers.GlobalMaxPooling2D: [],
292-
layers.GlobalMaxPooling3D: [],
293-
layers.MaxPooling1D: [],
294-
layers.MaxPooling2D: [],
295-
layers.MaxPooling3D: [],
296247
}
297248

298249
_RNN_CELLS_WEIGHTS_MAP = {
@@ -339,6 +290,11 @@ def supports(cls, layer):
339290
True/False whether the layer type is supported.
340291
341292
"""
293+
# Automatically enable layers with zero trainable weights.
294+
# Example: Reshape, AveragePooling2D, Maximum/Minimum, etc.
295+
if len(layer.trainable_weights) == 0:
296+
return True
297+
342298
if layer.__class__ in cls._LAYERS_WEIGHTS_MAP:
343299
return True
344300

@@ -363,6 +319,10 @@ def _is_rnn_layer(cls, layer):
363319

364320
@classmethod
365321
def _weight_names(cls, layer):
322+
# For layers with zero trainable weights, like Reshape, Pooling.
323+
if len(layer.trainable_weights) == 0:
324+
return []
325+
366326
return cls._LAYERS_WEIGHTS_MAP[layer.__class__]
367327

368328
@classmethod

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,14 @@ def testConvolutionalWeightsCA(self,
9292

9393
class CustomLayer(layers.Layer):
9494
"""A custom non-clusterable layer class."""
95+
def __init__(self, units=10):
96+
super(CustomLayer, self).__init__()
97+
self.add_weight(shape=(1, units),
98+
initializer='uniform',
99+
name='kernel')
95100

101+
def call(self, inputs):
102+
return tf.matmul(inputs, self.weights)
96103

97104
class ClusteringLookupRegistryTest(test.TestCase, parameterized.TestCase):
98105
"""Unit tests for the ClusteringLookupRegistry class."""
@@ -191,6 +198,11 @@ class CustomLayerFromClusterableLayer(layers.Dense):
191198
"""A custom layer class derived from a built-in clusterable layer."""
192199
pass
193200

201+
class CustomLayerFromClusterableLayerNoWeights(layers.Reshape):
202+
"""A custom layer class derived from a built-in clusterable layer,
203+
that does not have any weights."""
204+
pass
205+
194206
class MinimalRNNCell(keras.layers.Layer):
195207
"""A minimal RNN cell implementation."""
196208

@@ -250,7 +262,10 @@ def testDoesNotSupportKerasUnsupportedLayer(self):
250262
Verifies that ClusterRegistry does not support an unknown built-in layer.
251263
"""
252264
# ConvLSTM2D is a built-in keras layer but not supported.
253-
self.assertFalse(ClusterRegistry.supports(layers.ConvLSTM2D(2, (5, 5))))
265+
l = layers.ConvLSTM2D(2, (5, 5))
266+
# We need to build weights
267+
l.build(input_shape = (10, 10))
268+
self.assertFalse(ClusterRegistry.supports(l))
254269

255270
def testSupportsKerasRNNLayers(self):
256271
"""
@@ -265,8 +280,10 @@ def testDoesNotSupportKerasRNNLayerUnknownCell(self):
265280
Verifies that ClusterRegistry does not support a custom non-clusterable RNN
266281
cell.
267282
"""
268-
self.assertFalse(ClusterRegistry.supports(
269-
keras.layers.RNN(ClusterRegistryTest.MinimalRNNCell(32))))
283+
l = keras.layers.RNN(ClusterRegistryTest.MinimalRNNCell(32))
284+
# We need to build it to have weights
285+
l.build((10,1))
286+
self.assertFalse(ClusterRegistry.supports(l))
270287

271288
def testSupportsKerasRNNLayerClusterableCell(self):
272289
"""
@@ -285,19 +302,31 @@ def testDoesNotSupportCustomLayer(self):
285302
def testDoesNotSupportCustomLayerInheritedFromClusterableLayer(self):
286303
"""
287304
Verifies that ClusterRegistry does not support a custom layer derived from
288-
a clusterable layer.
305+
a clusterable layer if there are trainable weights.
306+
"""
307+
custom_layer = ClusterRegistryTest.CustomLayerFromClusterableLayer(10)
308+
custom_layer.build(input_shape=(10, 10))
309+
self.assertFalse(ClusterRegistry.supports(custom_layer))
310+
311+
def testSupportsCustomLayerInheritedFromClusterableLayerNoWeights(self):
312+
"""
313+
Verifies that ClusterRegistry supports a custom layer derived from
314+
a clusterable layer that does not have trainable weights.
289315
"""
290-
self.assertFalse(
291-
ClusterRegistry.supports(
292-
ClusterRegistryTest.CustomLayerFromClusterableLayer(10)))
316+
custom_layer = ClusterRegistryTest.\
317+
CustomLayerFromClusterableLayerNoWeights((7, 1))
318+
custom_layer.build(input_shape=(3, 4))
319+
self.assertTrue(ClusterRegistry.supports(custom_layer))
293320

294321
def testMakeClusterableRaisesErrorForKerasUnsupportedLayer(self):
295322
"""
296323
Verifies that an unsupported built-in layer cannot be made clusterable by
297324
calling make_clusterable().
298325
"""
326+
l = layers.ConvLSTM2D(2, (5, 5))
327+
l.build(input_shape = (10, 10))
299328
with self.assertRaises(ValueError):
300-
ClusterRegistry.make_clusterable(layers.ConvLSTM2D(2, (5, 5)))
329+
ClusterRegistry.make_clusterable(l)
301330

302331
def testMakeClusterableRaisesErrorForCustomLayer(self):
303332
"""
@@ -313,9 +342,10 @@ def testMakeClusterableRaisesErrorForCustomLayerInheritedFromClusterableLayer(
313342
Verifies that a non-clusterable layer derived from a clusterable layer
314343
cannot be made clusterable by calling make_clusterable().
315344
"""
345+
l = ClusterRegistryTest.CustomLayerFromClusterableLayer(10)
346+
l.build(input_shape = (10, 10))
316347
with self.assertRaises(ValueError):
317-
ClusterRegistry.make_clusterable(
318-
ClusterRegistryTest.CustomLayerFromClusterableLayer(10))
348+
ClusterRegistry.make_clusterable(l)
319349

320350
def testMakeClusterableWorksOnKerasClusterableLayer(self):
321351
"""
@@ -413,9 +443,12 @@ def testMakeClusterableRaisesErrorOnRNNLayersUnsupportedCell(self):
413443
Verifies that make_clusterable() raises an exception when invoked with a
414444
built-in RNN layer that contains a non-clusterable custom RNN cell.
415445
"""
446+
l = ClusterRegistryTest.MinimalRNNCell(5)
447+
# we need to build weights
448+
l.build(input_shape = (10, 1))
416449
with self.assertRaises(ValueError):
417450
ClusterRegistry.make_clusterable(layers.RNN(
418-
[layers.LSTMCell(10), ClusterRegistryTest.MinimalRNNCell(5)]))
451+
[layers.LSTMCell(10), l]))
419452

420453

421454
if __name__ == '__main__':

0 commit comments

Comments
 (0)