Skip to content

Commit d31c3fc

Browse files
committed
Fixes for GitHub Issue #979
Change-Id: Idee3388e4905b5bf41f9598fa49ce200e9682bc6
1 parent c8fc87f commit d31c3fc

File tree

5 files changed

+194
-21
lines changed

5 files changed

+194
-21
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
# ==============================================================================
1515
"""Clustering API functions for Keras models."""
1616

17+
import distutils.version
18+
import warnings
19+
1720
import tensorflow as tf
1821

1922
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
@@ -266,6 +269,10 @@ def _add_clustering_wrapper(layer):
266269
**kwargs,
267270
)
268271

272+
# Skip clustering if Conv2D layer has insufficient number of weights for type of clustering
273+
if isinstance(layer, tf.keras.layers.Conv2D) and not layer_has_enough_weights_to_cluster(layer, number_of_clusters, cluster_per_channel):
274+
return layer
275+
269276
return cluster_wrapper.ClusterWeights(layer, number_of_clusters,
270277
cluster_centroids_init,
271278
preserve_sparsity,
@@ -355,3 +362,36 @@ def _strip_clustering_wrapper(layer):
355362
# Just copy the model with the right callback
356363
return tf.keras.models.clone_model(
357364
model, input_tensors=None, clone_function=_strip_clustering_wrapper)
365+
366+
def layer_has_enough_weights_to_cluster(
    layer, number_of_clusters, cluster_per_channel):
  """Returns whether a Conv2D layer has enough weights to be clustered.

  The number of weights available to one clustering pass is the size of the
  layer kernel, divided by the number of filters when clustering per channel.
  Clustering is considered feasible only when that number is strictly greater
  than `number_of_clusters`; otherwise a warning is issued.

  Args:
    layer: A `tf.keras.layers.Conv2D` layer that already owns its weights.
    number_of_clusters: Number of cluster centroids requested.
    cluster_per_channel: Whether the weights would be clustered separately
      for each channel instead of over the whole kernel.

  Returns:
    True if the layer has more weights than requested clusters, else False.

  Raises:
    ValueError: If `layer` is not a Conv2D layer, or has no trainable weights.
  """
  if not isinstance(layer, tf.keras.layers.Conv2D):
    raise ValueError(
        f"Input layer should be Conv2D layer: {layer.name} given.")

  if not layer.trainable_weights:
    raise ValueError(f"Layer {layer.name} has no weights to cluster.")

  # Read the size from the static variable shape so the result is a plain
  # Python int and the comparisons below never evaluate a tensor as a bool.
  number_of_layer_weights = layer.kernel.shape.num_elements()

  if cluster_per_channel:
    # Use the filter count directly rather than the size of the bias tensor:
    # the bias is absent when `use_bias=False`, in which case indexing
    # `trainable_weights` would pick the kernel itself or go out of range.
    weights_to_cluster = number_of_layer_weights / layer.filters
  else:
    weights_to_cluster = number_of_layer_weights

  has_enough_weights = weights_to_cluster > number_of_clusters

  if not has_enough_weights:
    warnings.warn(
        f"Layer {layer.name} does not have enough weights to implement "
        f"{'per-channel ' if cluster_per_channel else ''}clustering."
        " \nNo clustering was implemented for this layer.\n")

  return has_enough_weights

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,65 @@ def do_checks(layer, layer_name):
288288
do_checks(clustered_model.layers[2], "conv1d")
289289
do_checks(clustered_model.layers[3], "conv1d_transpose")
290290

291+
@parameterized.parameters(
    (False, 16),  # number_of_clusters > Conv2D filters
    (True, 8),  # number_of_clusters < Conv2D filters (but clustering by channel)
    (True, 12),  # number_of_clusters = Conv2D filters
    (False, 12),  # number_of_clusters = Conv2D filters
)
def testEndToEnd1x1Conv2d(self, cluster_per_channel, number_of_clusters):
  """Tests end-to-end clustering of a model with a 1x1 Conv2D layer.

  Clustering is expected to be skipped for the Conv2D layer in every
  parameterized case, because the layer holds too few weights for the
  requested number of clusters. A warning must be raised instead.

  Args:
    cluster_per_channel: Whether per-channel clustering is requested.
    number_of_clusters: Number of clusters requested for the Conv2D layer.
  """
  inputs = keras.layers.Input(shape=(28, 28), batch_size=16)
  reshaped = keras.layers.Reshape(target_shape=(28, 28, 1))(inputs)
  conv_out = keras.layers.Conv2D(
      filters=12, kernel_size=(1, 1), activation=tf.nn.relu)(reshaped)
  model = keras.models.Model(inputs=inputs, outputs=[conv_out])

  # Kernel of the Conv2D layer (layers: Input, Reshape, Conv2D).
  # NOTE(review): this is a reference to the weight variable, not a copy of
  # its values - confirm the final comparison is as strict as intended.
  original_unique_weights = model.layers[2].weights[0]

  def apply_clustering(layer):
    # Only Conv2D layers are handed to the clustering API.
    if not isinstance(layer, keras.layers.Conv2D):
      return layer
    return cluster.cluster_weights(
        layer,
        number_of_clusters=number_of_clusters,
        cluster_per_channel=cluster_per_channel)

  # The user must be warned that clustering is not applied to this layer.
  with self.assertWarnsRegex(
      Warning, r'Layer conv2d does not have enough weights'):
    model_to_cluster = keras.models.clone_model(
        model, clone_function=apply_clustering)

  model_to_cluster.compile(
      loss=keras.losses.categorical_crossentropy,
      optimizer="adam",
      metrics=["accuracy"])

  features = np.random.randn(
      *self._batch(model.input.get_shape().as_list(), 16))
  targets = np.random.randn(
      *self._batch(model.output.get_shape().as_list(), 16))
  model_to_cluster.fit(features, targets, steps_per_epoch=1)

  clustered_model = cluster.strip_clustering(model_to_cluster)

  # The 1x1 Conv2D must be present under its original name, with a kernel
  # equal to the one captured from the unclustered model.
  conv_layer = clustered_model.layers[2]
  self.assertEqual(conv_layer.name, "conv2d")
  self.assertAllEqual(conv_layer.weights[0], original_unique_weights)
349+
291350
def testStripClusteringSequentialModelWithRegulariser(self):
292351
"""Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
293352
original_model = keras.Sequential([

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,10 @@ def setUp(self):
145145
}
146146

147147
def _build_clustered_layer_model(self, layer, input_shape=(10, 1)):
148-
wrapped_layer = cluster.cluster_weights(layer, **self.params)
149-
self.model.add(wrapped_layer)
150-
self.model.build(input_shape=input_shape)
151-
148+
self.model.add(keras.Input(shape=input_shape))
149+
self.model.add(layer)
150+
self.model.build()
151+
wrapped_layer = cluster.cluster_weights(self.model.layers[0], **self.params)
152152
return wrapped_layer
153153

154154
def _validate_clustered_layer(self, original_layer, wrapped_layer):
@@ -194,7 +194,7 @@ def testClusterKerasNonClusterableLayer(self):
194194
def testDepthwiseConv2DLayerNonClusterable(self):
195195
"""Verifies that we don't cluster a DepthwiseConv2D layer, because clustering of this type of layer gives big unrecoverable accuracy loss."""
196196
wrapped_layer = self._build_clustered_layer_model(
197-
self.keras_depthwiseconv2d_layer, input_shape=(1, 10, 10, 10))
197+
self.keras_depthwiseconv2d_layer, input_shape=(10, 10, 10))
198198

199199
self._validate_clustered_layer(self.keras_depthwiseconv2d_layer,
200200
wrapped_layer)
@@ -203,7 +203,7 @@ def testDepthwiseConv2DLayerNonClusterable(self):
203203
@keras_parameterized.run_all_keras_modes
204204
def testDenseLayer(self):
205205
"""Verifies that we can cluster a Dense layer."""
206-
input_shape = (4, 28, 1)
206+
input_shape = (28, 1)
207207
wrapped_layer = self._build_clustered_layer_model(
208208
self.keras_dense_layer,
209209
input_shape=input_shape
@@ -217,7 +217,7 @@ def testDenseLayer(self):
217217
@keras_parameterized.run_all_keras_modes
218218
def testConv1DLayer(self):
219219
"""Verifies that we can cluster a Conv1D layer."""
220-
input_shape = (4, 28, 1)
220+
input_shape = (28, 1)
221221
wrapped_layer = self._build_clustered_layer_model(
222222
self.keras_conv1d_layer,
223223
input_shape=input_shape)
@@ -230,7 +230,7 @@ def testConv1DLayer(self):
230230
@keras_parameterized.run_all_keras_modes
231231
def testConv1DTransposeLayer(self):
232232
"""Verifies that we can cluster a Conv1DTranspose layer."""
233-
input_shape = (4, 28, 1)
233+
input_shape = (28, 1)
234234
wrapped_layer = self._build_clustered_layer_model(
235235
self.keras_conv1d_tr_layer,
236236
input_shape=input_shape)
@@ -243,7 +243,7 @@ def testConv1DTransposeLayer(self):
243243
@keras_parameterized.run_all_keras_modes
244244
def testConv2DLayer(self):
245245
"""Verifies that we can cluster a Conv2D layer."""
246-
input_shape = (4, 28, 28, 1)
246+
input_shape = (28, 28, 1)
247247
wrapped_layer = self._build_clustered_layer_model(
248248
self.keras_conv2d_layer,
249249
input_shape=input_shape)
@@ -256,7 +256,7 @@ def testConv2DLayer(self):
256256
@keras_parameterized.run_all_keras_modes
257257
def testConv2DTransposeLayer(self):
258258
"""Verifies that we can cluster a Conv2DTranspose layer."""
259-
input_shape = (4, 28, 28, 1)
259+
input_shape = (28, 28, 1)
260260
wrapped_layer = self._build_clustered_layer_model(
261261
self.keras_conv2d_tr_layer,
262262
input_shape=input_shape)
@@ -269,7 +269,7 @@ def testConv2DTransposeLayer(self):
269269
@keras_parameterized.run_all_keras_modes
270270
def testConv3DLayer(self):
271271
"""Verifies that we can cluster a Conv3D layer."""
272-
input_shape = (4, 28, 28, 28, 1)
272+
input_shape = (28, 28, 28, 1)
273273
wrapped_layer = self._build_clustered_layer_model(
274274
self.keras_conv3d_layer,
275275
input_shape=input_shape)
@@ -732,7 +732,7 @@ def testClusterWeightsStrippedWeights(self):
732732
def testStrippedKernel(self):
733733
"""Verifies that stripping the clustering wrappers from a functional model restores the layers kernel and the layers weight array to the new clustered weight value."""
734734
i1 = keras.Input(shape=(1, 1, 1))
735-
x1 = layers.Conv2D(1, 1)(i1)
735+
x1 = layers.Conv2D(12, 1)(i1)
736736
outputs = x1
737737
model = keras.Model(inputs=[i1], outputs=outputs)
738738

tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_integration_test.py

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,13 @@ def _get_clustered_model(self, preserve_sparsity):
103103

104104
return clustered_model
105105

106-
def _get_conv_model(self, nr_of_channels, data_format=None):
106+
def _get_conv_model(self, nr_of_channels, data_format=None, kernel_size=(3,3)):
107107
"""Returns functional model with Conv2D layer."""
108108
inp = tf.keras.layers.Input(shape=(32, 32), batch_size=100)
109109
shape = (1, 32, 32) if data_format == 'channels_first' else (32, 32, 1)
110110
x = tf.keras.layers.Reshape(shape)(inp)
111111
x = tf.keras.layers.Conv2D(
112-
filters=nr_of_channels, kernel_size=(3, 3),
112+
filters=nr_of_channels, kernel_size=kernel_size,
113113
data_format=data_format,
114114
activation='relu')(x)
115115
x = tf.keras.layers.MaxPool2D(2, 2)(x)
@@ -131,10 +131,10 @@ def _compile_and_fit_conv_model(self, model, nr_epochs=1):
131131
return model
132132

133133
def _get_conv_clustered_model(self, nr_of_channels, nr_of_clusters,
134-
data_format, preserve_sparsity):
134+
data_format, preserve_sparsity, kernel_size=(3,3)):
135135
"""Returns clustered per channel model with Conv2D layer."""
136136
tf.random.set_seed(42)
137-
model = self._get_conv_model(nr_of_channels, data_format)
137+
model = self._get_conv_model(nr_of_channels, data_format, kernel_size)
138138

139139
if preserve_sparsity:
140140
# Make the convolutional layer sparse by nullifying half of weights
@@ -475,6 +475,75 @@ def testEndToEndPCQATClusteredPerChannel(self, data_format='channels_last'):
475475
cqat_sparsity = self._get_sparsity(stripped_cqat_model)
476476
self.assertLessEqual(cqat_sparsity[0], control_sparsity[0])
477477

478+
def testEndToEndPCQATClusteredPerChannelConv2d1x1(
    self, data_format='channels_last'):
  """Runs PCQAT for a model containing a 1x1 Conv2D layer.

  The 1x1 Conv2D has too few weights per channel to be clustered, so both
  the clustering step and the cluster-preserving quantization step are
  expected to warn and leave the kernel's unique weights untouched.

  Args:
    data_format: Data format passed to the Conv2D layer of the test model.
  """
  nr_of_channels = 12
  nr_of_clusters = 4

  # Ensure a warning is given to the user that
  # clustering is not implemented for this layer
  with self.assertWarnsRegex(
      Warning, r'Layer conv2d does not have enough weights'):
    clustered_model = self._get_conv_clustered_model(
        nr_of_channels,
        nr_of_clusters,
        data_format,
        preserve_sparsity=True,
        kernel_size=(1, 1))
  stripped_model = cluster.strip_clustering(clustered_model)

  # Save the kernel weights
  conv2d_layer = stripped_model.layers[2]
  self.assertEqual(conv2d_layer.name, 'conv2d')

  for weight in conv2d_layer.weights:
    if 'kernel' in weight.name:
      kernel_values = weight.numpy()

      # Original number of unique weights
      nr_original_weights = len(np.unique(kernel_values))
      self.assertLess(nr_original_weights, nr_of_channels * nr_of_clusters)

      # Demonstrate unmodified test layer has less weights
      # than requested clusters
      for channel in range(nr_of_channels):
        channel_weights = (
            kernel_values[:, channel, :, :]
            if data_format == "channels_first"
            else kernel_values[:, :, :, channel])
        # Count every weight in the channel; len() would only report the
        # size of the first dimension of the slice.
        nr_channel_weights = channel_weights.size
        self.assertGreater(nr_channel_weights, 0)
        self.assertLessEqual(nr_channel_weights, nr_of_clusters)

  # get sparsity before PCQAT training
  # we expect that only one value will be returned
  control_sparsity = self._get_sparsity(stripped_model)
  self.assertGreater(control_sparsity[0], 0.5)

  quant_aware_annotate_model = (
      quantize.quantize_annotate_model(stripped_model)
  )

  with self.assertWarnsRegex(
      Warning, r'No clustering performed on layer quant_conv2d'):
    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))

  # Lets train for more epochs to have a chance to scatter clusters
  model = self._compile_and_fit_conv_model(quant_aware_model, 3)

  stripped_cqat_model = strip_clustering_cqat(model)

  # Check the unique weights of a certain layer of
  # clustered_model and cqat_model, ensuring unchanged
  layer_nr = 3
  num_of_unique_weights_cqat = self._get_number_of_unique_weights(
      stripped_cqat_model, layer_nr, 'kernel')
  self.assertEqual(num_of_unique_weights_cqat, nr_original_weights)

  cqat_sparsity = self._get_sparsity(stripped_cqat_model)
  self.assertLessEqual(cqat_sparsity[0], control_sparsity[0])
546+
478547
def testPassingNonPrunedModelToPCQAT(self):
479548
"""Runs PCQAT as CQAT if the input model is not pruned."""
480549
preserve_sparsity = False

tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_quantize_registry.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
"""Registry responsible for built-in keras classes."""
1616

1717
import logging
18+
import warnings
1819

1920
import tensorflow as tf
21+
from tensorflow.python.keras import backend as K
2022

2123
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2224
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
@@ -259,6 +261,7 @@ def apply_cluster_preserve_quantize_config(self, layer, quantize_config):
259261
if self._no_trainable_weights(layer) or self._disable_cluster_preserve(
260262
layer):
261263
return quantize_config
264+
262265
# Example: Conv2D, Dense layers
263266
if quantize_config.__class__.__name__ in self._LAYERS_CONFIG_MAP[
264267
layer.__class__].quantize_config_attrs:
@@ -277,11 +280,6 @@ class Default8bitClusterPreserveQuantizeRegistry(
277280
ClusterPreserveQuantizeRegistry):
278281
"""Default 8 bit ClusterPreserveQuantizeRegistry."""
279282

280-
def __init__(self, preserve_sparsity):
281-
super(Default8bitClusterPreserveQuantizeRegistry, self).__init__(
282-
preserve_sparsity)
283-
self.preserve_sparsity = preserve_sparsity
284-
285283
def get_quantize_config(self, layer):
286284
"""Returns the quantization config with weight_quantizer for a given layer.
287285
@@ -364,6 +362,13 @@ def _build_clusters(self, name, layer):
364362
# Prepare clustering variables for the Keras graph when clusters
365363
# exist, assuming we do not use number_of_clusters larger than 1024
366364
if num_centroids > 1024:
365+
warnings.warn(f"No clustering performed on layer {layer.name}.\n" \
366+
"Too many centroids to cluster.")
367+
return result
368+
# If not enough clusters, we do not preserve clustering
369+
elif num_centroids <= 1:
370+
warnings.warn(f"No clustering performed on layer {layer.name}.\n" \
371+
"Perhaps too many clusters requested for this layer?")
367372
return result
368373
else:
369374
clst_centroids_tf = layer.add_weight(

0 commit comments

Comments
 (0)