Skip to content

Commit 9dc53f6

Browse files
Merge pull request #986 from jamwar01:patch_issue_979
PiperOrigin-RevId: 460333065
2 parents 24603dd + cbc0629 commit 9dc53f6

File tree

5 files changed

+229
-23
lines changed

5 files changed

+229
-23
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# ==============================================================================
1515
"""Clustering API functions for Keras models."""
1616

17+
import warnings
18+
1719
import tensorflow as tf
1820

1921
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
@@ -266,6 +268,14 @@ def _add_clustering_wrapper(layer):
266268
**kwargs,
267269
)
268270

271+
# Skip clustering if Conv2D layer has insufficient number of weights
272+
# for type of clustering
273+
if isinstance(
274+
layer,
275+
tf.keras.layers.Conv2D) and not layer_has_enough_weights_to_cluster(
276+
layer, number_of_clusters, cluster_per_channel):
277+
return layer
278+
269279
return cluster_wrapper.ClusterWeights(layer, number_of_clusters,
270280
cluster_centroids_init,
271281
preserve_sparsity,
@@ -355,3 +365,43 @@ def _strip_clustering_wrapper(layer):
355365
# Just copy the model with the right callback
356366
return tf.keras.models.clone_model(
357367
model, input_tensors=None, clone_function=_strip_clustering_wrapper)
368+
369+
370+
def layer_has_enough_weights_to_cluster(layer, number_of_clusters,
                                        cluster_per_channel):
  """Returns whether a Conv2D layer has enough weights to be clustered.

  The kernel is considered clusterable only when the number of weights that
  would share one set of centroids is strictly greater than
  `number_of_clusters`. When it is not, a warning is emitted so the caller can
  leave the layer unclustered.

  Args:
    layer: The `tf.keras.layers.Conv2D` layer to inspect. It must already own
      trainable weights.
    number_of_clusters: Number of cluster centroids requested.
    cluster_per_channel: If truthy, the kernel weights are counted per channel
      (kernel size divided by the size of the channel axis) instead of over
      the whole kernel.

  Returns:
    True if there are more weights to cluster than requested clusters,
    False otherwise.

  Raises:
    ValueError: If `layer` is not a Conv2D layer, or has no trainable weights.
  """
  if not isinstance(layer, tf.keras.layers.Conv2D):
    raise ValueError(f'Input layer should be Conv2D layer: {layer.name} given.')

  if not layer.trainable_weights:
    raise ValueError(f'Layer {layer.name} has no weights to cluster.')

  kernel = layer.kernel
  number_of_layer_weights = tf.cast(tf.size(kernel), tf.int32)

  if cluster_per_channel:
    # The channel axis must be looked up in the kernel's shape. Indexing
    # `layer.trainable_weights` with it picks the bias (or the whole kernel,
    # or nothing at all when the layer has no bias) instead of an axis length.
    # NOTE(review): assumes the per-channel clustering algorithm splits the
    # kernel along this same axis -- confirm against the clustering wrapper.
    channel_idx = 1 if layer.data_format == 'channels_first' else -1
    number_of_channels = tf.cast(tf.shape(kernel)[channel_idx], tf.int32)
    weights_to_cluster = number_of_layer_weights / number_of_channels
  else:
    weights_to_cluster = number_of_layer_weights

  has_enough_weights = bool(weights_to_cluster > number_of_clusters)

  if not has_enough_weights:
    per_channel = 'per-channel ' if cluster_per_channel else ''
    warnings.warn(
        f'Layer {layer.name} does not have enough weights to implement '
        f'{per_channel}clustering. \n'
        'No clustering was implemented for this layer.\n')
  return has_enough_weights

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,74 @@ def do_checks(layer, layer_name):
288288
do_checks(clustered_model.layers[2], "conv1d")
289289
do_checks(clustered_model.layers[3], "conv1d_transpose")
290290

291+
@parameterized.parameters(
    (False, 16),  # number_of_clusters > Conv2D filters
    (True, 8),  # number_of_clusters < Conv2D filters (clustering by channel)
    (True, 12),  # number_of_clusters = Conv2D filters
    (False, 12),  # number_of_clusters = Conv2D filters
)
def testEndToEnd1x1Conv2d(self, cluster_per_channel, number_of_clusters):
  """End-to-end check that a 1x1 Conv2D layer is left unclustered.

  Every parameter combination requests at least as many clusters as the
  layer has weights to cluster, so a warning is expected and the kernel of
  the stripped model must equal the kernel captured from the original model.

  Args:
    cluster_per_channel: Forwarded to `cluster.cluster_weights`.
    number_of_clusters: Forwarded to `cluster.cluster_weights`.
  """
  inputs = keras.layers.Input(shape=(28, 28), batch_size=16)
  reshaped = keras.layers.Reshape(target_shape=(28, 28, 1))(inputs)
  conv_out = keras.layers.Conv2D(
      filters=12, kernel_size=(1, 1), activation=tf.nn.relu)(
          reshaped)
  model = keras.models.Model(inputs=inputs, outputs=[conv_out])

  # Keep a handle on the original Conv2D kernel for the final comparison.
  original_kernel = model.layers[2].weights[0]

  def maybe_cluster(layer):
    # Only Conv2D layers are handed to the clustering API.
    if not isinstance(layer, keras.layers.Conv2D):
      return layer
    return cluster.cluster_weights(
        layer,
        number_of_clusters=number_of_clusters,
        cluster_per_channel=cluster_per_channel)

  # The user must be warned that clustering is not implemented for the layer.
  with self.assertWarnsRegex(
      Warning, r"Layer conv2d does not have enough weights"):
    model_to_cluster = keras.models.clone_model(
        model, clone_function=maybe_cluster)

  model_to_cluster.compile(
      optimizer="adam",
      loss=keras.losses.categorical_crossentropy,
      metrics=["accuracy"])
  features = np.random.randn(
      *self._batch(model.input.get_shape().as_list(), 16))
  targets = np.random.randn(
      *self._batch(model.output.get_shape().as_list(), 16))
  model_to_cluster.fit(features, targets, steps_per_epoch=1)
  clustered_model = cluster.strip_clustering(model_to_cluster)

  # Clustering must not have been performed on the 1x1 Conv: the layer keeps
  # its name and its kernel matches the one taken from the original model.
  final_conv = clustered_model.layers[2]
  self.assertEqual(final_conv.name, "conv2d")
  self.assertAllEqual(final_conv.weights[0], original_kernel)
358+
291359
def testStripClusteringSequentialModelWithRegulariser(self):
292360
"""Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
293361
original_model = keras.Sequential([

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,10 @@ def setUp(self):
145145
}
146146

147147
def _build_clustered_layer_model(self, layer, input_shape=(10, 1)):
148-
wrapped_layer = cluster.cluster_weights(layer, **self.params)
149-
self.model.add(wrapped_layer)
150-
self.model.build(input_shape=input_shape)
151-
148+
self.model.add(keras.Input(shape=input_shape))
149+
self.model.add(layer)
150+
self.model.build()
151+
wrapped_layer = cluster.cluster_weights(self.model.layers[0], **self.params)
152152
return wrapped_layer
153153

154154
def _validate_clustered_layer(self, original_layer, wrapped_layer):
@@ -194,7 +194,7 @@ def testClusterKerasNonClusterableLayer(self):
194194
def testDepthwiseConv2DLayerNonClusterable(self):
195195
"""Verifies that we don't cluster a DepthwiseConv2D layer, because clustering of this type of layer gives big unrecoverable accuracy loss."""
196196
wrapped_layer = self._build_clustered_layer_model(
197-
self.keras_depthwiseconv2d_layer, input_shape=(1, 10, 10, 10))
197+
self.keras_depthwiseconv2d_layer, input_shape=(10, 10, 10))
198198

199199
self._validate_clustered_layer(self.keras_depthwiseconv2d_layer,
200200
wrapped_layer)
@@ -203,7 +203,7 @@ def testDepthwiseConv2DLayerNonClusterable(self):
203203
@keras_parameterized.run_all_keras_modes
204204
def testDenseLayer(self):
205205
"""Verifies that we can cluster a Dense layer."""
206-
input_shape = (4, 28, 1)
206+
input_shape = (28, 1)
207207
wrapped_layer = self._build_clustered_layer_model(
208208
self.keras_dense_layer,
209209
input_shape=input_shape
@@ -217,7 +217,7 @@ def testDenseLayer(self):
217217
@keras_parameterized.run_all_keras_modes
218218
def testConv1DLayer(self):
219219
"""Verifies that we can cluster a Conv1D layer."""
220-
input_shape = (4, 28, 1)
220+
input_shape = (28, 1)
221221
wrapped_layer = self._build_clustered_layer_model(
222222
self.keras_conv1d_layer,
223223
input_shape=input_shape)
@@ -230,7 +230,7 @@ def testConv1DLayer(self):
230230
@keras_parameterized.run_all_keras_modes
231231
def testConv1DTransposeLayer(self):
232232
"""Verifies that we can cluster a Conv1DTranspose layer."""
233-
input_shape = (4, 28, 1)
233+
input_shape = (28, 1)
234234
wrapped_layer = self._build_clustered_layer_model(
235235
self.keras_conv1d_tr_layer,
236236
input_shape=input_shape)
@@ -243,7 +243,7 @@ def testConv1DTransposeLayer(self):
243243
@keras_parameterized.run_all_keras_modes
244244
def testConv2DLayer(self):
245245
"""Verifies that we can cluster a Conv2D layer."""
246-
input_shape = (4, 28, 28, 1)
246+
input_shape = (28, 28, 1)
247247
wrapped_layer = self._build_clustered_layer_model(
248248
self.keras_conv2d_layer,
249249
input_shape=input_shape)
@@ -256,7 +256,7 @@ def testConv2DLayer(self):
256256
@keras_parameterized.run_all_keras_modes
257257
def testConv2DTransposeLayer(self):
258258
"""Verifies that we can cluster a Conv2DTranspose layer."""
259-
input_shape = (4, 28, 28, 1)
259+
input_shape = (28, 28, 1)
260260
wrapped_layer = self._build_clustered_layer_model(
261261
self.keras_conv2d_tr_layer,
262262
input_shape=input_shape)
@@ -269,7 +269,7 @@ def testConv2DTransposeLayer(self):
269269
@keras_parameterized.run_all_keras_modes
270270
def testConv3DLayer(self):
271271
"""Verifies that we can cluster a Conv3D layer."""
272-
input_shape = (4, 28, 28, 28, 1)
272+
input_shape = (28, 28, 28, 1)
273273
wrapped_layer = self._build_clustered_layer_model(
274274
self.keras_conv3d_layer,
275275
input_shape=input_shape)
@@ -732,7 +732,7 @@ def testClusterWeightsStrippedWeights(self):
732732
def testStrippedKernel(self):
733733
"""Verifies that stripping the clustering wrappers from a functional model restores the layers kernel and the layers weight array to the new clustered weight value."""
734734
i1 = keras.Input(shape=(1, 1, 1))
735-
x1 = layers.Conv2D(1, 1)(i1)
735+
x1 = layers.Conv2D(12, 1)(i1)
736736
outputs = x1
737737
model = keras.Model(inputs=[i1], outputs=outputs)
738738

tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_integration_test.py

Lines changed: 90 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,20 @@ def _get_clustered_model(self, preserve_sparsity):
103103

104104
return clustered_model
105105

106-
def _get_conv_model(self, nr_of_channels, data_format=None):
106+
def _get_conv_model(self,
107+
nr_of_channels,
108+
data_format=None,
109+
kernel_size=(3, 3)):
107110
"""Returns functional model with Conv2D layer."""
108111
inp = tf.keras.layers.Input(shape=(32, 32), batch_size=100)
109112
shape = (1, 32, 32) if data_format == 'channels_first' else (32, 32, 1)
110113
x = tf.keras.layers.Reshape(shape)(inp)
111114
x = tf.keras.layers.Conv2D(
112-
filters=nr_of_channels, kernel_size=(3, 3),
115+
filters=nr_of_channels,
116+
kernel_size=kernel_size,
113117
data_format=data_format,
114-
activation='relu')(x)
118+
activation='relu')(
119+
x)
115120
x = tf.keras.layers.MaxPool2D(2, 2)(x)
116121
out = tf.keras.layers.Flatten()(x)
117122
model = tf.keras.Model(inputs=inp, outputs=out)
@@ -130,11 +135,15 @@ def _compile_and_fit_conv_model(self, model, nr_epochs=1):
130135

131136
return model
132137

133-
def _get_conv_clustered_model(self, nr_of_channels, nr_of_clusters,
134-
data_format, preserve_sparsity):
138+
def _get_conv_clustered_model(self,
139+
nr_of_channels,
140+
nr_of_clusters,
141+
data_format,
142+
preserve_sparsity,
143+
kernel_size=(3, 3)):
135144
"""Returns clustered per channel model with Conv2D layer."""
136145
tf.random.set_seed(42)
137-
model = self._get_conv_model(nr_of_channels, data_format)
146+
model = self._get_conv_model(nr_of_channels, data_format, kernel_size)
138147

139148
if preserve_sparsity:
140149
# Make the convolutional layer sparse by nullifying half of weights
@@ -475,6 +484,81 @@ def testEndToEndPCQATClusteredPerChannel(self, data_format='channels_last'):
475484
cqat_sparsity = self._get_sparsity(stripped_cqat_model)
476485
self.assertLessEqual(cqat_sparsity[0], control_sparsity[0])
477486

487+
def testEndToEndPCQATClusteredPerChannelConv2d1x1(
    self, data_format='channels_last'):
  """Runs PCQAT on a model whose Conv2D layer uses a 1x1 kernel.

  Both the clustering step and the cluster-preserving quantization step are
  expected to warn that the layer was not clustered. After training, the
  number of unique kernel values and the sparsity are compared with the
  values measured before quantization.

  Args:
    data_format: Format of input data.
  """
  filters = 12
  requested_clusters = 4

  # The user must be warned that clustering is not implemented for the layer.
  with self.assertWarnsRegex(
      Warning, r'Layer conv2d does not have enough weights'):
    clustered_model = self._get_conv_clustered_model(
        filters,
        requested_clusters,
        data_format,
        preserve_sparsity=True,
        kernel_size=(1, 1))
  stripped_model = cluster.strip_clustering(clustered_model)

  conv_layer = stripped_model.layers[2]
  self.assertEqual(conv_layer.name, 'conv2d')

  channels_first = data_format == 'channels_first'
  for variable in conv_layer.weights:
    if 'kernel' not in variable.name:
      continue

    # Number of unique kernel values before quantization-aware training.
    unique_before = len(np.unique(variable.numpy()))
    self.assertLess(unique_before, filters * requested_clusters)

    # Each per-channel slice must be non-empty and no longer than the
    # number of requested clusters.
    for channel in range(filters):
      if channels_first:
        channel_slice = variable[:, channel, :, :]
      else:
        channel_slice = variable[:, :, :, channel]
      slice_length = len(channel_slice)
      self.assertGreater(slice_length, 0)
      self.assertLessEqual(slice_length, requested_clusters)

  # Sparsity before PCQAT training; only the first returned value is used.
  control_sparsity = self._get_sparsity(stripped_model)
  self.assertGreater(control_sparsity[0], 0.5)

  annotated_model = quantize.quantize_annotate_model(stripped_model)
  scheme = (
      default_8bit_cluster_preserve_quantize_scheme
      .Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))

  with self.assertWarnsRegex(
      Warning, r'No clustering performed on layer quant_conv2d'):
    quant_aware_model = quantize.quantize_apply(annotated_model, scheme=scheme)

  # Train for more epochs to give the clusters a chance to scatter.
  trained_model = self._compile_and_fit_conv_model(quant_aware_model, 3)
  stripped_cqat_model = strip_clustering_cqat(trained_model)

  # The number of unique kernel values in layer 3 of the stripped CQAT model
  # must be unchanged from the count taken before training.
  unique_after = self._get_number_of_unique_weights(
      stripped_cqat_model, 3, 'kernel')
  self.assertEqual(unique_after, unique_before)

  cqat_sparsity = self._get_sparsity(stripped_cqat_model)
  self.assertLessEqual(cqat_sparsity[0], control_sparsity[0])
561+
478562
def testPassingNonPrunedModelToPCQAT(self):
479563
"""Runs PCQAT as CQAT if the input model is not pruned."""
480564
preserve_sparsity = False

tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_quantize_registry.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Registry responsible for built-in keras classes."""
1616

1717
import logging
18+
import warnings
1819

1920
import tensorflow as tf
2021

@@ -259,6 +260,7 @@ def apply_cluster_preserve_quantize_config(self, layer, quantize_config):
259260
if self._no_trainable_weights(layer) or self._disable_cluster_preserve(
260261
layer):
261262
return quantize_config
263+
262264
# Example: Conv2D, Dense layers
263265
if quantize_config.__class__.__name__ in self._LAYERS_CONFIG_MAP[
264266
layer.__class__].quantize_config_attrs:
@@ -277,11 +279,6 @@ class Default8bitClusterPreserveQuantizeRegistry(
277279
ClusterPreserveQuantizeRegistry):
278280
"""Default 8 bit ClusterPreserveQuantizeRegistry."""
279281

280-
def __init__(self, preserve_sparsity):
281-
super(Default8bitClusterPreserveQuantizeRegistry, self).__init__(
282-
preserve_sparsity)
283-
self.preserve_sparsity = preserve_sparsity
284-
285282
def get_quantize_config(self, layer):
286283
"""Returns the quantization config with weight_quantizer for a given layer.
287284
@@ -364,6 +361,13 @@ def _build_clusters(self, name, layer):
364361
# Prepare clustering variables for the Keras graph when clusters
365362
# exist, assuming we do not use number_of_clusters larger than 1024
366363
if num_centroids > 1024:
364+
warnings.warn(f'No clustering performed on layer {layer.name}.\n'
365+
f'Too many centroids to cluster.')
366+
return result
367+
# If not enough clusters, we do not preserve clustering
368+
elif num_centroids <= 1:
369+
warnings.warn(f'No clustering performed on layer {layer.name}.\n'
370+
f'Perhaps too many clusters requested for this layer?')
367371
return result
368372
else:
369373
clst_centroids_tf = layer.add_weight(

0 commit comments

Comments
 (0)