Skip to content

Commit 270435d

Browse files
Merge pull request #539 from wwwind:clustering_registry_improvement
PiperOrigin-RevId: 333336338
2 parents 87c06eb + 33a8030 commit 270435d

File tree

4 files changed

+102
-70
lines changed

4 files changed

+102
-70
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def setUp(self):
6565
self.keras_unsupported_layer = layers.ConvLSTM2D(2, (5, 5)) # Unsupported
6666
self.custom_clusterable_layer = CustomClusterableLayer(10)
6767
self.custom_non_clusterable_layer = CustomNonClusterableLayer(10)
68+
self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1))
6869

6970
clustering_registry.ClusteringLookupRegistry.register_new_implementation(
7071
{
@@ -81,10 +82,10 @@ def setUp(self):
8182
cluster_config.CentroidInitialization.DENSITY_BASED
8283
}
8384

84-
def _build_clustered_layer_model(self, layer):
85+
def _build_clustered_layer_model(self, layer, input_shape=(10, 1)):
8586
wrapped_layer = cluster.cluster_weights(layer, **self.params)
8687
self.model.add(wrapped_layer)
87-
self.model.build(input_shape=(10, 1))
88+
self.model.build(input_shape=input_shape)
8889

8990
return wrapped_layer
9091

@@ -124,13 +125,32 @@ def testClusterKerasNonClusterableLayer(self):
124125
wrapped_layer)
125126
self.assertEqual([], wrapped_layer.layer.get_clusterable_weights())
126127

128+
@keras_parameterized.run_all_keras_modes
129+
def testDepthwiseConv2DLayerNonClusterable(self):
130+
"""
131+
Verifies that we don't cluster a DepthwiseConv2D layer,
132+
because clustering of this type of layer gives
133+
big unrecoverable accuracy loss.
134+
"""
135+
wrapped_layer = self._build_clustered_layer_model(
136+
self.keras_depthwiseconv2d_layer,
137+
input_shape=(1, 10, 10, 10)
138+
)
139+
140+
self._validate_clustered_layer(self.keras_depthwiseconv2d_layer,
141+
wrapped_layer)
142+
self.assertEqual([], wrapped_layer.layer.get_clusterable_weights())
143+
127144
def testClusterKerasUnsupportedLayer(self):
128145
"""
129146
Verifies that attempting to cluster an unsupported layer raises an
130147
exception.
131148
"""
149+
keras_unsupported_layer = self.keras_unsupported_layer
150+
# We need to build the weights before the check.
151+
keras_unsupported_layer.build(input_shape = (10, 10))
132152
with self.assertRaises(ValueError):
133-
cluster.cluster_weights(self.keras_unsupported_layer, **self.params)
153+
cluster.cluster_weights(keras_unsupported_layer, **self.params)
134154

135155
@keras_parameterized.run_all_keras_modes
136156
def testClusterCustomClusterableLayer(self):
@@ -149,8 +169,14 @@ def testClusterCustomNonClusterableLayer(self):
149169
Verifies that attempting to cluster a custom non-clusterable layer raises
150170
an exception.
151171
"""
172+
custom_non_clusterable_layer = self.custom_non_clusterable_layer
173+
# While the layer is still empty (no weights allocated), clustering is supported.
174+
cluster_wrapper.ClusterWeights(custom_non_clusterable_layer,
175+
**self.params)
176+
# We need to build the weights before checking that clustering is not supported.
177+
custom_non_clusterable_layer.build(input_shape=(10, 10))
152178
with self.assertRaises(ValueError):
153-
cluster_wrapper.ClusterWeights(self.custom_non_clusterable_layer,
179+
cluster_wrapper.ClusterWeights(custom_non_clusterable_layer,
154180
**self.params)
155181

156182
@keras_parameterized.run_all_keras_modes
@@ -206,11 +232,14 @@ def testClusterModelUnsupportedKerasLayerRaisesError(self):
206232
Verifies that attempting to cluster a model that contains an unsupported
207233
layer raises an exception.
208234
"""
235+
keras_unsupported_layer = self.keras_unsupported_layer
236+
# We need to build the weights before the check.
237+
keras_unsupported_layer.build(input_shape = (10, 10))
209238
with self.assertRaises(ValueError):
210239
cluster.cluster_weights(
211240
keras.Sequential([
212241
self.keras_clusterable_layer, self.keras_non_clusterable_layer,
213-
self.custom_clusterable_layer, self.keras_unsupported_layer
242+
self.custom_clusterable_layer, keras_unsupported_layer
214243
]), **self.params)
215244

216245
def testClusterModelCustomNonClusterableLayerRaisesError(self):
@@ -219,10 +248,13 @@ def testClusterModelCustomNonClusterableLayerRaisesError(self):
219248
non-clusterable layer raises an exception.
220249
"""
221250
with self.assertRaises(ValueError):
251+
custom_non_clusterable_layer = self.custom_non_clusterable_layer
252+
# We need to build the weights before the check.
253+
custom_non_clusterable_layer.build(input_shape = (1, 2))
222254
cluster.cluster_weights(
223255
keras.Sequential([
224256
self.keras_clusterable_layer, self.keras_non_clusterable_layer,
225-
self.custom_clusterable_layer, self.custom_non_clusterable_layer
257+
self.custom_clusterable_layer, custom_non_clusterable_layer
226258
]), **self.params)
227259

228260
@keras_parameterized.run_all_keras_modes

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,16 @@
3232
CentroidInitialization = cluster_config.CentroidInitialization
3333

3434

35-
class NonClusterableLayer(layers.Dense):
36-
"""A custom layer that is not clusterable."""
37-
35+
class NonClusterableLayer(layers.Layer):
36+
"""A custom layer with weights that is not clusterable."""
37+
def __init__(self, units=10):
38+
super(NonClusterableLayer, self).__init__()
39+
self.add_weight(shape=(1, units),
40+
initializer='uniform',
41+
name='kernel')
42+
43+
def call(self, inputs):
44+
return tf.matmul(inputs, self.weights)
3845

3946
class AlreadyClusterableLayer(layers.Dense, clusterable_layer.ClusterableLayer):
4047
"""A custom layer that is clusterable."""

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py

Lines changed: 10 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import abc
1818
import six
1919
import tensorflow as tf
20-
2120
from tensorflow.keras import layers
2221

2322
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
@@ -260,69 +259,21 @@ class ClusteringRegistry(object):
260259
# the variables within the layers which hold the kernel weights. This
261260
# allows the wrapper to access and modify the weights.
262261
_LAYERS_WEIGHTS_MAP = {
263-
layers.ELU: [],
264-
layers.LeakyReLU: [],
265-
layers.ReLU: [],
266-
layers.Softmax: [],
267-
layers.ThresholdedReLU: [],
268262
layers.Conv1D: ['kernel'],
269263
layers.Conv2D: ['kernel'],
270264
layers.Conv2DTranspose: ['kernel'],
271265
layers.Conv3D: ['kernel'],
272266
layers.Conv3DTranspose: ['kernel'],
273-
layers.Cropping1D: [],
274-
layers.Cropping2D: [],
275-
layers.Cropping3D: [],
267+
# non-clusterable due to big unrecoverable accuracy loss
276268
layers.DepthwiseConv2D: [],
277269
layers.SeparableConv1D: ['pointwise_kernel'],
278270
layers.SeparableConv2D: ['pointwise_kernel'],
279-
layers.UpSampling1D: [],
280-
layers.UpSampling2D: [],
281-
layers.UpSampling3D: [],
282-
layers.ZeroPadding1D: [],
283-
layers.ZeroPadding2D: [],
284-
layers.ZeroPadding3D: [],
285-
layers.Activation: [],
286-
layers.ActivityRegularization: [],
287271
layers.Dense: ['kernel'],
288-
layers.Dropout: [],
289-
layers.Flatten: [],
290-
layers.Lambda: [],
291-
layers.Masking: [],
292-
layers.Permute: [],
293-
layers.RepeatVector: [],
294-
layers.Reshape: [],
295-
layers.SpatialDropout1D: [],
296-
layers.SpatialDropout2D: [],
297-
layers.SpatialDropout3D: [],
298272
layers.Embedding: ['embeddings'],
299273
layers.LocallyConnected1D: ['kernel'],
300274
layers.LocallyConnected2D: ['kernel'],
301-
layers.Add: [],
302-
layers.Average: [],
303-
layers.Concatenate: [],
304-
layers.Dot: [],
305-
layers.Maximum: [],
306-
layers.Minimum: [],
307-
layers.Multiply: [],
308-
layers.Subtract: [],
309-
layers.AlphaDropout: [],
310-
layers.GaussianDropout: [],
311-
layers.GaussianNoise: [],
312275
layers.BatchNormalization: [],
313276
layers.LayerNormalization: [],
314-
layers.AveragePooling1D: [],
315-
layers.AveragePooling2D: [],
316-
layers.AveragePooling3D: [],
317-
layers.GlobalAveragePooling1D: [],
318-
layers.GlobalAveragePooling2D: [],
319-
layers.GlobalAveragePooling3D: [],
320-
layers.GlobalMaxPooling1D: [],
321-
layers.GlobalMaxPooling2D: [],
322-
layers.GlobalMaxPooling3D: [],
323-
layers.MaxPooling1D: [],
324-
layers.MaxPooling2D: [],
325-
layers.MaxPooling3D: [],
326277
}
327278

328279
_RNN_CELLS_WEIGHTS_MAP = {
@@ -369,6 +320,11 @@ def supports(cls, layer):
369320
True/False whether the layer type is supported.
370321
371322
"""
323+
# Automatically enable layers with zero trainable weights.
324+
# Example: Reshape, AveragePooling2D, Maximum/Minimum, etc.
325+
if len(layer.trainable_weights) == 0:
326+
return True
327+
372328
if layer.__class__ in cls._LAYERS_WEIGHTS_MAP:
373329
return True
374330

@@ -393,6 +349,10 @@ def _is_rnn_layer(cls, layer):
393349

394350
@classmethod
395351
def _weight_names(cls, layer):
352+
# For layers with zero trainable weights, like Reshape, Pooling.
353+
if len(layer.trainable_weights) == 0:
354+
return []
355+
396356
return cls._LAYERS_WEIGHTS_MAP[layer.__class__]
397357

398358
@classmethod

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,14 @@ def testConvolutionalWeightsCA(self,
157157

158158
class CustomLayer(layers.Layer):
159159
"""A custom non-clusterable layer class."""
160+
def __init__(self, units=10):
161+
super(CustomLayer, self).__init__()
162+
self.add_weight(shape=(1, units),
163+
initializer='uniform',
164+
name='kernel')
160165

166+
def call(self, inputs):
167+
return tf.matmul(inputs, self.weights)
161168

162169
class ClusteringLookupRegistryTest(test.TestCase, parameterized.TestCase):
163170
"""Unit tests for the ClusteringLookupRegistry class."""
@@ -256,6 +263,11 @@ class CustomLayerFromClusterableLayer(layers.Dense):
256263
"""A custom layer class derived from a built-in clusterable layer."""
257264
pass
258265

266+
class CustomLayerFromClusterableLayerNoWeights(layers.Reshape):
267+
"""A custom layer class derived from a built-in clusterable layer,
268+
that does not have any weights."""
269+
pass
270+
259271
class MinimalRNNCell(keras.layers.Layer):
260272
"""A minimal RNN cell implementation."""
261273

@@ -315,7 +327,10 @@ def testDoesNotSupportKerasUnsupportedLayer(self):
315327
Verifies that ClusterRegistry does not support an unknown built-in layer.
316328
"""
317329
# ConvLSTM2D is a built-in keras layer but not supported.
318-
self.assertFalse(ClusterRegistry.supports(layers.ConvLSTM2D(2, (5, 5))))
330+
l = layers.ConvLSTM2D(2, (5, 5))
331+
# We need to build weights
332+
l.build(input_shape = (10, 10))
333+
self.assertFalse(ClusterRegistry.supports(l))
319334

320335
def testSupportsKerasRNNLayers(self):
321336
"""
@@ -330,8 +345,10 @@ def testDoesNotSupportKerasRNNLayerUnknownCell(self):
330345
Verifies that ClusterRegistry does not support a custom non-clusterable RNN
331346
cell.
332347
"""
333-
self.assertFalse(ClusterRegistry.supports(
334-
keras.layers.RNN(ClusterRegistryTest.MinimalRNNCell(32))))
348+
l = keras.layers.RNN(ClusterRegistryTest.MinimalRNNCell(32))
349+
# We need to build it to have weights
350+
l.build((10,1))
351+
self.assertFalse(ClusterRegistry.supports(l))
335352

336353
def testSupportsKerasRNNLayerClusterableCell(self):
337354
"""
@@ -350,19 +367,31 @@ def testDoesNotSupportCustomLayer(self):
350367
def testDoesNotSupportCustomLayerInheritedFromClusterableLayer(self):
351368
"""
352369
Verifies that ClusterRegistry does not support a custom layer derived from
353-
a clusterable layer.
370+
a clusterable layer if there are trainable weights.
371+
"""
372+
custom_layer = ClusterRegistryTest.CustomLayerFromClusterableLayer(10)
373+
custom_layer.build(input_shape=(10, 10))
374+
self.assertFalse(ClusterRegistry.supports(custom_layer))
375+
376+
def testSupportsCustomLayerInheritedFromClusterableLayerNoWeights(self):
377+
"""
378+
Verifies that ClusterRegistry supports a custom layer derived from
379+
a clusterable layer that does not have trainable weights.
354380
"""
355-
self.assertFalse(
356-
ClusterRegistry.supports(
357-
ClusterRegistryTest.CustomLayerFromClusterableLayer(10)))
381+
custom_layer = ClusterRegistryTest.\
382+
CustomLayerFromClusterableLayerNoWeights((7, 1))
383+
custom_layer.build(input_shape=(3, 4))
384+
self.assertTrue(ClusterRegistry.supports(custom_layer))
358385

359386
def testMakeClusterableRaisesErrorForKerasUnsupportedLayer(self):
360387
"""
361388
Verifies that an unsupported built-in layer cannot be made clusterable by
362389
calling make_clusterable().
363390
"""
391+
l = layers.ConvLSTM2D(2, (5, 5))
392+
l.build(input_shape = (10, 10))
364393
with self.assertRaises(ValueError):
365-
ClusterRegistry.make_clusterable(layers.ConvLSTM2D(2, (5, 5)))
394+
ClusterRegistry.make_clusterable(l)
366395

367396
def testMakeClusterableRaisesErrorForCustomLayer(self):
368397
"""
@@ -378,9 +407,10 @@ def testMakeClusterableRaisesErrorForCustomLayerInheritedFromClusterableLayer(
378407
Verifies that a non-clusterable layer derived from a clusterable layer
379408
cannot be made clusterable by calling make_clusterable().
380409
"""
410+
l = ClusterRegistryTest.CustomLayerFromClusterableLayer(10)
411+
l.build(input_shape = (10, 10))
381412
with self.assertRaises(ValueError):
382-
ClusterRegistry.make_clusterable(
383-
ClusterRegistryTest.CustomLayerFromClusterableLayer(10))
413+
ClusterRegistry.make_clusterable(l)
384414

385415
def testMakeClusterableWorksOnKerasClusterableLayer(self):
386416
"""
@@ -478,9 +508,12 @@ def testMakeClusterableRaisesErrorOnRNNLayersUnsupportedCell(self):
478508
Verifies that make_clusterable() raises an exception when invoked with a
479509
built-in RNN layer that contains a non-clusterable custom RNN cell.
480510
"""
511+
l = ClusterRegistryTest.MinimalRNNCell(5)
512+
# we need to build weights
513+
l.build(input_shape = (10, 1))
481514
with self.assertRaises(ValueError):
482515
ClusterRegistry.make_clusterable(layers.RNN(
483-
[layers.LSTMCell(10), ClusterRegistryTest.MinimalRNNCell(5)]))
516+
[layers.LSTMCell(10), l]))
484517

485518

486519
if __name__ == '__main__':

0 commit comments

Comments
 (0)