Skip to content

Commit a9cd0b0

Browse files
committed
Support of clustering per channel in (P)CQAT.
Change-Id: I70d28b00f335b044e0fed3db7779a6cbbd56ded8
1 parent b278157 commit a9cd0b0

File tree

6 files changed

+430
-35
lines changed

6 files changed

+430
-35
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ def testMHA(self):
586586

587587
class ClusterPerChannelIntegrationTest(tf.test.TestCase,
588588
parameterized.TestCase):
589-
"""Integration tests for per-channel clustering of Conv2D layer."""
589+
"""Integration test for cluster_per_channel for Conv2D layer."""
590590

591591
def setUp(self):
592592
super(ClusterPerChannelIntegrationTest, self).setUp()
@@ -624,15 +624,27 @@ def testPerChannel(self):
624624
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
625625
metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")])
626626
clustered_model.fit(
627-
self.x_train, self.y_train, epochs=1, batch_size=100, verbose=1)
627+
self.x_train, self.y_train, epochs=2, batch_size=100, verbose=1)
628628

629629
stripped_model = cluster.strip_clustering(clustered_model)
630630

631631
layer_conv2d = stripped_model.layers[2]
632+
layer_conv2d_kernel_weight = None
632633
for weight in layer_conv2d.weights:
633634
if "kernel" in weight.name:
634-
nr_unique_weights = len(np.unique(weight.numpy()))
635-
assert nr_unique_weights == self.nr_of_clusters*self.num_channels
635+
layer_conv2d_kernel_weight = weight
636+
assert layer_conv2d_kernel_weight is not None
637+
nr_unique_weights = len(np.unique(layer_conv2d_kernel_weight.numpy()))
638+
assert nr_unique_weights == self.nr_of_clusters*self.num_channels
639+
640+
# The above check is too general.
641+
# We need to check that we actually have nr_of_clusters per channel.
642+
# If the more general check passed, we then do a tighter check.
643+
# Note that we assume that data_format is NHWC.
644+
for i in range(self.num_channels):
645+
nr_unique_weights_per_channel = len(np.unique(
646+
layer_conv2d_kernel_weight[:, :, :, i]))
647+
assert nr_unique_weights_per_channel == self.nr_of_clusters
636648

637649
if __name__ == "__main__":
638650
test.main()

tensorflow_model_optimization/python/core/clustering/keras/clustering_algorithm.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353
clusters centroids.
5454
cluster_gradient_aggregation: An enum that specify the aggregation method
5555
of the cluster gradient.
56-
data_format: To be used in Per-Channel clustering to ensure the weight
56+
data_format: To be used in cluster_per_channel to ensure the weight
5757
kernel is permuted properly when updating the weights and calculating
5858
gradients
5959
"""
@@ -194,10 +194,24 @@ def get_clustered_weight(self, pulling_indices, original_weight):
194194
return clustered_weight
195195

196196

197-
class PerChannelCA(ClusteringAlgorithm):
197+
class ClusteringAlgorithmPerChannel(ClusteringAlgorithm):
198198
"""Class for Per-channel clustering of Conv2D layers."""
199199

200200
def get_pulling_indices(self, weight):
201+
"""Returns indices of closest cluster centroids.
202+
203+
This function is based on the function get_pulling_indices
204+
of the base class ClusteringAlgorithm. We apply each per
205+
channel of the convolutional layer.
206+
207+
Args:
208+
weight: ND array of weights. For each weight in this array the closest
209+
cluster centroid is found.
210+
211+
Returns:
212+
ND array of the same shape as `weight` parameter of the type
213+
tf.int32. The returned array contains weight lookup indices.
214+
"""
201215
channel_indices = []
202216

203217
num_channels = (weight.shape[1] if self.data_format == "channels_first"
@@ -213,7 +227,13 @@ def get_pulling_indices(self, weight):
213227

214228
channel_indices.append(pulling_indices)
215229

216-
return tf.convert_to_tensor(channel_indices)
230+
pulling_indices = tf.convert_to_tensor(channel_indices)
231+
pulling_indices = (
232+
tf.transpose(pulling_indices, perm=[1, 0, 2, 3])
233+
if self.data_format == "channels_first" else tf.transpose(
234+
pulling_indices, perm=[1, 2, 3, 0]))
235+
236+
return pulling_indices
217237

218238
def get_clustered_weight(self, pulling_indices, original_weight):
219239
"""Returns clustered weights with custom gradients.
@@ -240,6 +260,16 @@ def get_clustered_weight(self, pulling_indices, original_weight):
240260
original_weight.shape[1]
241261
if self.data_format == "channels_first" else original_weight.shape[-1])
242262

263+
# In case of channels_last, we have NHWC
264+
# In case of channels_first, we have NCHW
265+
266+
# We need to transpose the tensor, so C is the first dimension
267+
# and then we could loop over channels
268+
pulling_indices = (
269+
tf.transpose(pulling_indices, perm=[1, 0, 2, 3])
270+
if self.data_format == "channels_first" else tf.transpose(
271+
pulling_indices, perm=[3, 0, 1, 2]))
272+
243273
if self.cluster_gradient_aggregation == GradientAggregation.SUM:
244274
cluster_centroids = self.cluster_centroids
245275
elif self.cluster_gradient_aggregation == GradientAggregation.AVG:
@@ -268,7 +298,7 @@ def get_clustered_weight(self, pulling_indices, original_weight):
268298

269299
for i in range(num_channels):
270300
clustered_weights.append(
271-
tf.gather(cluster_centroids[i], pulling_indices[i]))
301+
tf.gather(cluster_centroids[i], pulling_indices[i]))
272302

273303
clustered_weight = tf.convert_to_tensor(clustered_weights)
274304

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
layers = tf.keras.layers
2323
ClusteringAlgorithm = clustering_algorithm.ClusteringAlgorithm
24-
PerChannelCA = clustering_algorithm.PerChannelCA
24+
ClusteringAlgorithmPerChannel = clustering_algorithm.ClusteringAlgorithmPerChannel
2525

2626

2727
class ClusteringLookupRegistry(object):
@@ -43,7 +43,7 @@ def get_clustering_impl(cls, layer, weight_name, cluster_per_channel=False):
4343
# Per-channel clustering is only applied if the layer is a Conv2D,
4444
# ignored otherwise
4545
if cluster_per_channel and isinstance(layer, tf.keras.layers.Conv2D):
46-
return PerChannelCA
46+
return ClusteringAlgorithmPerChannel
4747

4848
# Clusterable layer could provide own implementation of get_pulling_indices
4949
if (issubclass(layer.__class__, clusterable_layer.ClusterableLayer) and

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py

Lines changed: 98 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -184,43 +184,91 @@ def testConvolutionalWeightsCA(self, clustering_centroids, pulling_indices,
184184
self._check_pull_values(clustering_algo, pulling_indices, expected_output)
185185

186186
@parameterized.parameters(
187-
([[0., 1, 2], [3, 4, 5]],
188-
[[[[0], [0]], [[0], [1]]],
189-
[[[0], [2]], [[1], [0]]]],
190-
[[[[0], [0]], [[0], [0]]],
191-
[[[0], [0]], [[1], [1]]]]))
187+
("channels_last",
188+
[[1, 2], [3, 4], [5, 6]], # 3 channels and 2 clusters per channel
189+
# pulling indices has shape (2, 2, 1, 3)
190+
[[[[0, 1, 0]], [[0, 1, 1]]], [[[1, 0, 1]], [[0, 1, 0]]]],
191+
[[[[1, 4, 5]], [[1, 4, 6]]], [[[2, 3, 6]], [[1, 4, 5]]]]),
192+
("channels_first",
193+
[[1, 2], [3, 4], [4, 5], [6, 7]], # 4 channels and 2 clusters per channel
194+
# pulling indices has shape (1, 4, 2, 2)
195+
[[[[0, 1], [1, 1]], [[0, 0], [0, 1]],
196+
[[1, 0], [0, 0]], [[1, 1], [0, 0]]]],
197+
[[[[1, 2], [2, 2]], [[3, 3], [3, 4]],
198+
[[5, 4], [4, 4]], [[7, 7], [6, 6]]]])
199+
)
192200
def testConvolutionalWeightsPerChannelCA(self,
201+
data_format,
193202
clustering_centroids,
194203
pulling_indices,
195204
expected_output):
196-
"""Verifies that PerChannelCA works as expected."""
205+
"""Verifies that get_clustered_weight function works as expected."""
197206
clustering_centroids = tf.Variable(clustering_centroids, dtype=tf.float32)
198-
clustering_algo = clustering_registry.PerChannelCA(
199-
clustering_centroids, GradientAggregation.SUM
207+
clustering_algo = clustering_registry.ClusteringAlgorithmPerChannel(
208+
clustering_centroids, GradientAggregation.SUM, data_format
200209
)
210+
# Note that clustered_weights has the same shape as pulling_indices,
211+
# because they are defined inside of the check function.
201212
self._check_pull_values(clustering_algo, pulling_indices, expected_output)
202213

214+
@parameterized.parameters(
215+
("channels_last",
216+
[[1, 2], [3, 4], [5, 6]], # 3 channels and 2 clusters per channel
217+
# weight has shape (2, 2, 1, 3)
218+
[[[[1.1, 3.2, 5.2]], [[2.0, 4.1, 5.2]]],
219+
[[[2.1, 2., 6.1]], [[1., 5., 5.]]]],
220+
# expected pulling indices
221+
[[[[0, 0, 0]], [[1, 1, 0]]], [[[1, 0, 1]], [[0, 1, 0]]]]),
222+
("channels_first",
223+
# 4 channels and 2 clusters per channel
224+
[[1, 2], [3, 4], [4, 5], [6, 7]],
225+
# weight has shape (1, 4, 2, 2)
226+
[[[[0.1, 1.5], [2.0, 1.1]], [[0., 3.5], [4.4, 4.]],
227+
[[4.1, 4.2], [5.3, 6.]], [[7., 7.1], [6.1, 5.8]]]],
228+
# expected pulling indices
229+
[[[[0, 0], [1, 0]], [[0, 0], [1, 1]],
230+
[[0, 0], [1, 1]], [[1, 1], [0, 0]]]])
231+
)
232+
def testConvolutionalPullingIndicesPerChannelCA(self,
233+
data_format,
234+
clustering_centroids,
235+
weight,
236+
expected_output):
237+
"""Verifies that get_pulling_indices function works as expected."""
238+
clustering_centroids = tf.Variable(clustering_centroids, dtype=tf.float32)
239+
clustering_algo = clustering_registry.ClusteringAlgorithmPerChannel(
240+
clustering_centroids, GradientAggregation.SUM, data_format
241+
)
242+
weight = tf.convert_to_tensor(weight)
243+
pulling_indices = clustering_algo.get_pulling_indices(weight)
244+
245+
# check that pulling_indices has the same shape as weight
246+
self.assertEqual(pulling_indices.shape, weight.shape)
247+
self.assertAllEqual(pulling_indices, expected_output)
248+
203249
@parameterized.parameters(
204250
(GradientAggregation.AVG,
205-
[[[[0], [0]], [[0], [1]]],
206-
[[[0], [2]], [[1], [0]]]], [[1, 1, 0], [1, 1, 1]]),
251+
[[[[0, 0, 0]], [[1, 1, 0]]],
252+
[[[1, 0, 1]], [[0, 1, 0]]]],
253+
[[1, 1], [1, 1], [1, 1]]),
207254
(GradientAggregation.SUM,
208-
[[[[0], [0]], [[0], [1]]],
209-
[[[0], [2]], [[1], [0]]]], [[3, 1, 0], [2, 1, 1]])
255+
[[[[0, 0, 0]], [[1, 1, 0]]],
256+
[[[1, 0, 1]], [[0, 1, 0]]]],
257+
[[2, 2], [2, 2], [3, 1]])
210258
)
211-
def testConvolutionalPerChannelCAGrad(self,
259+
def testConvolutionalPerChannelCAGradChannelsLast(self,
212260
cluster_gradient_aggregation,
213261
pulling_indices,
214262
expected_grad_centroids):
215-
"""Verifies that the gradients of convolutional layer work as expected."""
263+
"""Verifies that the gradients of a convolutional layer work."""
216264

217-
clustering_centroids = tf.Variable([[0., 1, 2], [3, 4, 5]],
265+
clustering_centroids = tf.Variable([[1, 2], [3, 4], [5, 6]],
218266
dtype=tf.float32)
219-
weight = tf.constant([[[[0.1, 3.0]], [[0.2, 0.1]]],
220-
[[[0.1, 3.0]], [[0.2, 0.1]]]])
267+
weight = tf.constant([[[[1.1, 3.2, 5.2]], [[2.0, 4.1, 5.2]]],
268+
[[[2.1, 2., 6.1]], [[1., 5., 5.]]]])
221269

222-
clustering_algo = clustering_registry.PerChannelCA(
223-
clustering_centroids, cluster_gradient_aggregation
270+
clustering_algo = clustering_registry.ClusteringAlgorithmPerChannel(
271+
clustering_centroids, cluster_gradient_aggregation, "channels_last"
224272
)
225273
self._check_gradients_clustered_weight(
226274
clustering_algo,
@@ -229,6 +277,37 @@ def testConvolutionalPerChannelCAGrad(self,
229277
expected_grad_centroids,
230278
)
231279

280+
@parameterized.parameters(
281+
(GradientAggregation.AVG,
282+
[[[[0, 0], [1, 0]], [[0, 0], [1, 1]],
283+
[[0, 0], [1, 1]], [[1, 1], [0, 0]]]],
284+
[[1, 1], [1, 1], [1, 1], [1, 1]]),
285+
(GradientAggregation.SUM,
286+
[[[[0, 0], [1, 0]], [[0, 0], [1, 1]],
287+
[[0, 0], [1, 1]], [[1, 1], [0, 0]]]],
288+
[[3, 1], [2, 2], [2, 2], [2, 2]])
289+
)
290+
def testConvolutionalPerChannelCAGradChannelsFirst(self,
291+
cluster_gradient_aggregation,
292+
pulling_indices,
293+
expected_grad_centroids):
294+
"""Verifies that the gradients of a convolutional layer work."""
295+
296+
clustering_centroids = tf.Variable([[1, 2], [3, 4], [4, 5], [6, 7]],
297+
dtype=tf.float32)
298+
weight = tf.constant([[[[0.1, 1.5], [2.0, 1.1]],
299+
[[0., 3.5], [4.4, 4.]], [[4.1, 4.2], [5.3, 6.]],
300+
[[7., 7.1], [6.1, 5.8]]]])
301+
302+
clustering_algo = clustering_registry.ClusteringAlgorithmPerChannel(
303+
clustering_centroids, cluster_gradient_aggregation, "channels_first"
304+
)
305+
self._check_gradients_clustered_weight(
306+
clustering_algo,
307+
weight,
308+
pulling_indices,
309+
expected_grad_centroids,
310+
)
232311

233312
class CustomLayer(layers.Layer):
234313
"""A custom non-clusterable layer class."""

0 commit comments

Comments
 (0)