Skip to content

Commit 537cefb

Browse files
Merge pull request #920 from wwwind:clustering_per_channel_cqat
PiperOrigin-RevId: 430602869
2 parents 1ba8dbb + a9cd0b0 commit 537cefb

File tree

7 files changed

+407
-41
lines changed

7 files changed

+407
-41
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ def testMHA(self):
562562

563563
class ClusterPerChannelIntegrationTest(tf.test.TestCase,
564564
parameterized.TestCase):
565-
"""Integration tests for per-channel clustering of Conv2D layer."""
565+
"""Integration test for cluster_per_channel for Conv2D layer."""
566566

567567
def setUp(self):
568568
super(ClusterPerChannelIntegrationTest, self).setUp()
@@ -600,15 +600,28 @@ def testPerChannel(self):
600600
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
601601
metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")])
602602
clustered_model.fit(
603-
self.x_train, self.y_train, epochs=1, batch_size=100, verbose=1)
603+
self.x_train, self.y_train, epochs=2, batch_size=100, verbose=1)
604604

605605
stripped_model = cluster.strip_clustering(clustered_model)
606606

607607
layer_conv2d = stripped_model.layers[2]
608+
layer_conv2d_kernel_weight = None
608609
for weight in layer_conv2d.weights:
609610
if "kernel" in weight.name:
610-
nr_unique_weights = len(np.unique(weight.numpy()))
611-
assert nr_unique_weights == self.nr_of_clusters*self.num_channels
611+
layer_conv2d_kernel_weight = weight
612+
self.assertIsNotNone(layer_conv2d_kernel_weight)
613+
nr_unique_weights = len(np.unique(layer_conv2d_kernel_weight.numpy()))
614+
self.assertEqual(nr_unique_weights, self.nr_of_clusters*self.num_channels)
615+
616+
# The above check is too general.
617+
# We need to check that we actually have nr_of_clusters per channel.
618+
# If more general case passed, then we do tighter check.
619+
# Note that we assume that data_format is NHWC.
620+
for i in range(self.num_channels):
621+
nr_unique_weights_per_channel = len(
622+
np.unique(layer_conv2d_kernel_weight[:, :, :, i]))
623+
self.assertEqual(nr_unique_weights_per_channel, self.nr_of_clusters)
624+
612625

613626
if __name__ == "__main__":
614627
test.main()

tensorflow_model_optimization/python/core/clustering/keras/clustering_algorithm.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353
clusters centroids.
5454
cluster_gradient_aggregation: An enum that specify the aggregation method
5555
of the cluster gradient.
56-
data_format: To be used in Per-Channel clustering to ensure the weight
56+
data_format: To be used in cluster_per_channel to ensure the weight
5757
kernel is permuted properly when updating the weights and calculating
5858
gradients
5959
"""
@@ -194,10 +194,24 @@ def get_clustered_weight(self, pulling_indices, original_weight):
194194
return clustered_weight
195195

196196

197-
class PerChannelCA(ClusteringAlgorithm):
197+
class ClusteringAlgorithmPerChannel(ClusteringAlgorithm):
198198
"""Class for Per-channel clustering of Conv2D layers."""
199199

200200
def get_pulling_indices(self, weight):
201+
"""Returns indices of closest cluster centroids.
202+
203+
This function is based on the function get_pulling_indices
204+
of the base class ClusteringAlgorithm. We apply each per
205+
channel of the convolutional layer.
206+
207+
Args:
208+
weight: ND array of weights. For each weight in this array the closest
209+
cluster centroids is found.
210+
211+
Returns:
212+
ND array of the same shape as `weight` parameter of the type
213+
tf.int32. The returned array contain weight lookup indices.
214+
"""
201215
channel_indices = []
202216

203217
num_channels = (weight.shape[1] if self.data_format == "channels_first"
@@ -213,7 +227,13 @@ def get_pulling_indices(self, weight):
213227

214228
channel_indices.append(pulling_indices)
215229

216-
return tf.convert_to_tensor(channel_indices)
230+
pulling_indices = tf.convert_to_tensor(channel_indices)
231+
pulling_indices = tf.transpose(
232+
pulling_indices,
233+
perm=(1, 0, 2, 3) if self.data_format == "channels_first" else
234+
(1, 2, 3, 0))
235+
236+
return pulling_indices
217237

218238
def get_clustered_weight(self, pulling_indices, original_weight):
219239
"""Returns clustered weights with custom gradients.
@@ -240,6 +260,16 @@ def get_clustered_weight(self, pulling_indices, original_weight):
240260
original_weight.shape[1]
241261
if self.data_format == "channels_first" else original_weight.shape[-1])
242262

263+
# In case of channels_last, we have NHWC.
264+
# In case of channels_first, we have NCHW.
265+
# We need to transpose the tensor, so C is the first dimension
266+
# and then we could loop over channels
267+
pulling_indices = (
268+
tf.transpose(
269+
pulling_indices,
270+
perm=(1, 0, 2, 3) if self.data_format == "channels_first" else
271+
(3, 0, 1, 2)))
272+
243273
if self.cluster_gradient_aggregation == GradientAggregation.SUM:
244274
cluster_centroids = self.cluster_centroids
245275
elif self.cluster_gradient_aggregation == GradientAggregation.AVG:

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
layers = tf.keras.layers
2323
ClusteringAlgorithm = clustering_algorithm.ClusteringAlgorithm
24-
PerChannelCA = clustering_algorithm.PerChannelCA
24+
ClusteringAlgorithmPerChannel = clustering_algorithm.ClusteringAlgorithmPerChannel
2525

2626

2727
class ClusteringLookupRegistry(object):
@@ -43,7 +43,7 @@ def get_clustering_impl(cls, layer, weight_name, cluster_per_channel=False):
4343
# Per-channel clustering is only applied if the layer is a Conv2D,
4444
# ignored otherwise
4545
if cluster_per_channel and isinstance(layer, tf.keras.layers.Conv2D):
46-
return PerChannelCA
46+
return ClusteringAlgorithmPerChannel
4747

4848
# Clusterable layer could provide own implementation of get_pulling_indices
4949
if (issubclass(layer.__class__, clusterable_layer.ClusterableLayer) and

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py

Lines changed: 93 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -184,44 +184,111 @@ def testConvolutionalWeightsCA(self, clustering_centroids, pulling_indices,
184184
self._check_pull_values(clustering_algo, pulling_indices, expected_output)
185185

186186
@parameterized.parameters(
187-
([[0., 1, 2], [3, 4, 5]],
188-
[[[[0], [0]], [[0], [1]]],
189-
[[[0], [2]], [[1], [0]]]],
190-
[[[[0], [0]], [[0], [0]]],
191-
[[[0], [0]], [[1], [1]]]]))
192-
def testConvolutionalWeightsPerChannelCA(self,
187+
(
188+
'channels_last',
189+
[[1, 2], [3, 4], [5, 6]], # 3 channels and 2 cluster per channel
190+
# pulling indices has shape (2, 2, 1, 3)
191+
[[[[0, 1, 0]], [[0, 1, 1]]], [[[1, 0, 1]], [[0, 1, 0]]]],
192+
[[[[1, 4, 5]], [[1, 4, 6]]], [[[2, 3, 6]], [[1, 4, 5]]]]),
193+
(
194+
'channels_first',
195+
[[1, 2], [3, 4], [4, 5], [6, 7]
196+
], # 4 channels and 2 clusters per channel
197+
# pulling indices has shape (1, 4, 2, 2)
198+
[[[[0, 1], [1, 1]], [[0, 0], [0, 1]], [[1, 0], [0, 0]],
199+
[[1, 1], [0, 0]]]],
200+
[[[[1, 2], [2, 2]], [[3, 3], [3, 4]], [[5, 4], [4, 4]],
201+
[[7, 7], [6, 6]]]]))
202+
def testConvolutionalWeightsPerChannelCA(self, data_format,
193203
clustering_centroids,
194204
pulling_indices,
195205
expected_output):
196-
"""Verifies that PerChannelCA works as expected."""
206+
"""Verifies that get_clustered_weight function works as expected."""
197207
clustering_centroids = tf.Variable(clustering_centroids, dtype=tf.float32)
198-
clustering_algo = clustering_registry.PerChannelCA(
199-
clustering_centroids, GradientAggregation.SUM
208+
clustering_algo = clustering_registry.ClusteringAlgorithmPerChannel(
209+
clustering_centroids, GradientAggregation.SUM, data_format
200210
)
211+
# Note that clustered_weights has the same shape as pulling_indices,
212+
# because they are defined inside of the check function.
201213
self._check_pull_values(clustering_algo, pulling_indices, expected_output)
202214

203215
@parameterized.parameters(
204-
(GradientAggregation.AVG,
205-
[[[[0], [0]], [[0], [1]]],
206-
[[[0], [2]], [[1], [0]]]], [[1, 1, 0], [1, 1, 1]]),
207-
(GradientAggregation.SUM,
208-
[[[[0], [0]], [[0], [1]]],
209-
[[[0], [2]], [[1], [0]]]], [[3, 1, 0], [2, 1, 1]])
210-
)
211-
def testConvolutionalPerChannelCAGrad(self,
212-
cluster_gradient_aggregation,
213-
pulling_indices,
214-
expected_grad_centroids):
215-
"""Verifies that the gradients of convolutional layer work as expected."""
216+
(
217+
'channels_last',
218+
[[1, 2], [3, 4], [5, 6]], # 3 channels and 2 cluster per channel
219+
# weight has shape (2, 2, 1, 3)
220+
[[[[1.1, 3.2, 5.2]], [[2.0, 4.1, 5.2]]],
221+
[[[2.1, 2., 6.1]], [[1., 5., 5.]]]],
222+
# expected pulling indices
223+
[[[[0, 0, 0]], [[1, 1, 0]]], [[[1, 0, 1]], [[0, 1, 0]]]]),
224+
(
225+
'channels_first',
226+
# 4 channels and 2 clusters per channel
227+
[[1, 2], [3, 4], [4, 5], [6, 7]],
228+
# weight has shape (1, 4, 2, 2)
229+
[[[[0.1, 1.5], [2.0, 1.1]], [[0., 3.5], [4.4, 4.]],
230+
[[4.1, 4.2], [5.3, 6.]], [[7., 7.1], [6.1, 5.8]]]],
231+
# expected pulling indices
232+
[[[[0, 0], [1, 0]], [[0, 0], [1, 1]], [[0, 0], [1, 1]],
233+
[[1, 1], [0, 0]]]]))
234+
def testConvolutionalPullingIndicesPerChannelCA(self, data_format,
235+
clustering_centroids, weight,
236+
expected_output):
237+
"""Verifies that get_pulling_indices function works as expected."""
238+
clustering_centroids = tf.Variable(clustering_centroids, dtype=tf.float32)
239+
clustering_algo = clustering_registry.ClusteringAlgorithmPerChannel(
240+
clustering_centroids, GradientAggregation.SUM, data_format
241+
)
242+
weight = tf.convert_to_tensor(weight)
243+
pulling_indices = clustering_algo.get_pulling_indices(weight)
216244

217-
clustering_centroids = tf.Variable([[0., 1, 2], [3, 4, 5]],
245+
# check that pulling_indices has the same shape as weight
246+
self.assertEqual(pulling_indices.shape, weight.shape)
247+
self.assertAllEqual(pulling_indices, expected_output)
248+
249+
@parameterized.parameters(
250+
(GradientAggregation.AVG, [
251+
[[[0, 0, 0]], [[1, 1, 0]]], [[[1, 0, 1]], [[0, 1, 0]]]],
252+
[[1, 1], [1, 1], [1, 1]]),
253+
(GradientAggregation.SUM, [
254+
[[[0, 0, 0]], [[1, 1, 0]]], [[[1, 0, 1]], [[0, 1, 0]]]],
255+
[[2, 2], [2, 2], [3, 1]]))
256+
def testConvolutionalPerChannelCAGradChannelsLast(
257+
self, cluster_gradient_aggregation, pulling_indices,
258+
expected_grad_centroids):
259+
"""Verifies that the gradients of convolutional layer works."""
260+
261+
clustering_centroids = tf.Variable([[1, 2], [3, 4], [5, 6]],
218262
dtype=tf.float32)
219-
weight = tf.constant([[[[0.1, 3.0]], [[0.2, 0.1]]],
220-
[[[0.1, 3.0]], [[0.2, 0.1]]]])
263+
weight = tf.constant([[[[1.1, 3.2, 5.2]], [[2.0, 4.1, 5.2]]],
264+
[[[2.1, 2., 6.1]], [[1., 5., 5.]]]])
221265

222-
clustering_algo = clustering_registry.PerChannelCA(
223-
clustering_centroids, cluster_gradient_aggregation
266+
clustering_algo = clustering_registry.ClusteringAlgorithmPerChannel(
267+
clustering_centroids, cluster_gradient_aggregation, 'channels_last')
268+
self._check_gradients_clustered_weight(
269+
clustering_algo,
270+
weight,
271+
pulling_indices,
272+
expected_grad_centroids,
224273
)
274+
275+
@parameterized.parameters((GradientAggregation.AVG, [
276+
[[[0, 0], [1, 0]], [[0, 0], [1, 1]], [[0, 0], [1, 1]], [[1, 1], [0, 0]]]
277+
], [[1, 1], [1, 1], [1, 1], [1, 1]]), (GradientAggregation.SUM, [
278+
[[[0, 0], [1, 0]], [[0, 0], [1, 1]], [[0, 0], [1, 1]], [[1, 1], [0, 0]]]
279+
], [[3, 1], [2, 2], [2, 2], [2, 2]]))
280+
def testConvolutionalPerChannelCAGradChannelsFirst(
281+
self, cluster_gradient_aggregation, pulling_indices,
282+
expected_grad_centroids):
283+
"""Verifies that the gradients of convolutional layer works."""
284+
285+
clustering_centroids = tf.Variable([[1, 2], [3, 4], [4, 5], [6, 7]],
286+
dtype=tf.float32)
287+
weight = tf.constant([[[[0.1, 1.5], [2.0, 1.1]], [[0., 3.5], [4.4, 4.]],
288+
[[4.1, 4.2], [5.3, 6.]], [[7., 7.1], [6.1, 5.8]]]])
289+
290+
clustering_algo = clustering_registry.ClusteringAlgorithmPerChannel(
291+
clustering_centroids, cluster_gradient_aggregation, 'channels_first')
225292
self._check_gradients_clustered_weight(
226293
clustering_algo,
227294
weight,

tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ py_strict_library(
3737
deps = [
3838
":cluster_utils",
3939
# tensorflow dep1,
40+
"//tensorflow_model_optimization/python/core/clustering/keras:cluster_config",
4041
"//tensorflow_model_optimization/python/core/clustering/keras:clustering_registry",
4142
"//tensorflow_model_optimization/python/core/quantization/keras:quant_ops",
4243
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",

0 commit comments

Comments
 (0)