
Commit 78fbab2

Merge pull request #871 from MohamedNourArm:toupstream/per_channel_clustering
PiperOrigin-RevId: 415427526
2 parents 94374e5 + 97aca5b commit 78fbab2

11 files changed: +575 additions, −144 deletions

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 1 addition & 0 deletions
@@ -139,6 +139,7 @@ py_strict_test(
         ":cluster_config",
         ":clustering_centroids",
         # absl/testing:parameterized dep1,
+        # numpy dep1,
         # tensorflow dep1,
     ],
 )

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 27 additions & 13 deletions
@@ -124,16 +124,16 @@ def cluster_weights(
     ValueError: if the keras layer is unsupported, or the keras model contains
       an unsupported layer.
   """
-  return _cluster_weights(
-      to_cluster,
-      number_of_clusters,
-      cluster_centroids_init,
-      preserve_sparsity=False,
-      **kwargs)
+  return _cluster_weights(to_cluster, number_of_clusters,
+                          cluster_centroids_init, **kwargs)


-def _cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init,
-                     preserve_sparsity, **kwargs):
+def _cluster_weights(to_cluster,
+                     number_of_clusters,
+                     cluster_centroids_init,
+                     preserve_sparsity=False,
+                     cluster_per_channel=False,
+                     **kwargs):
   """Modifies a keras layer or model to be clustered during training.

   This function wraps a keras model or layer with clustering functionality

@@ -158,6 +158,7 @@ def _cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init,
   clustering_params = {
     'number_of_clusters': 8,
     'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
+    'cluster_per_channel': False,
     'preserve_sparsity': False
   }

@@ -170,6 +171,7 @@ def _cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init,
   clustering_params = {
     'number_of_clusters': 8,
     'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
+    'cluster_per_channel': False,
     'preserve_sparsity': False
   }

@@ -204,6 +206,17 @@ def _cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init,
       instance that determines how the cluster centroids will be initialized.
     preserve_sparsity (experimental): optional boolean value that determines
       whether or not sparsity preservation will be enforced during training.
+      When used along with cluster_per_channel flag below, the zero centroid
+      is treated separately and maintained individually for each channel.
+    cluster_per_channel: optional boolean value that determines whether the
+      clustering should be applied separately on the individual channels, as
+      opposed to the whole kernel. Only applicable to Conv2D layers and is
+      ignored otherwise. The number of clusters in this case would be
+      num_clusters*num_channels. This is useful for the collaborative
+      optimization pipeline where clustering is followed by quantization,
+      since Conv2D is quantized per-channel, so we end up with
+      num_clusters*num_channels total clusters at the end. Clustering
+      per-channel from the beginning leads to better accuracy.
     **kwargs: Additional keyword arguments to be passed to the keras layer.
       Ignored when to_cluster is not a keras layer.

@@ -255,7 +268,8 @@ def _add_clustering_wrapper(layer):

     return cluster_wrapper.ClusterWeights(layer, number_of_clusters,
                                           cluster_centroids_init,
-                                          preserve_sparsity, **kwargs)
+                                          preserve_sparsity,
+                                          cluster_per_channel, **kwargs)

   def _wrap_list(layers):
     output = []

@@ -310,11 +324,11 @@ def _strip_clustering_wrapper(layer):
         layer, input_tensors=None, clone_function=_strip_clustering_wrapper)

   elif isinstance(layer, cluster_wrapper.ClusterWeightsMHA):
-       # Update cluster associations in order to get the latest weights
-       layer.update_clustered_weights_associations()
+    # Update cluster associations in order to get the latest weights
+    layer.update_clustered_weights_associations()

-       # In case of MHA layer, use the overloaded implementation
-       return layer.strip_clustering()
+    # In case of MHA layer, use the overloaded implementation
+    return layer.strip_clustering()

   elif isinstance(layer, cluster_wrapper.ClusterWeights):
     # Update cluster associations in order to get the latest weights
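
The docstring above explains the new cluster_per_channel option. Below is a minimal, hedged usage sketch (not part of this commit) of how the flag can be exercised end-to-end through the public tfmot.clustering.keras API, assuming a build of tensorflow-model-optimization that contains this change and that the flag reaches _cluster_weights via **kwargs, as the new integration test does; the toy model, hyperparameters, and the way the Conv2D layer is located are illustrative only.

# Minimal usage sketch (not part of this commit). Assumes a build of
# tensorflow-model-optimization that contains this change.
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
    'number_of_clusters': 4,
    'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
    'cluster_per_channel': True,  # the flag introduced by this commit
}

# Toy model: a Conv2D with 12 output channels, so per-channel clustering
# should leave at most 4 * 12 unique kernel values after stripping.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32, 32, 1)),
    tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])

clustered = tfmot.clustering.keras.cluster_weights(model, **clustering_params)
# ... compile and fine-tune `clustered` here, as in the integration test ...

stripped = tfmot.clustering.keras.strip_clustering(clustered)
conv = next(l for l in stripped.layers
            if isinstance(l, tf.keras.layers.Conv2D))
print('unique kernel values:',
      len(np.unique(conv.kernel.numpy())))  # expected <= 4 * 12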

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 66 additions & 11 deletions
@@ -538,23 +538,26 @@ def testClusterStackedRNNCells(self):
         expected_unique_weights=self.params_clustering["number_of_clusters"],
     )

+
 class ClusterMHAIntegrationTest(tf.test.TestCase, parameterized.TestCase):
   """Integration tests for clustering MHA layer."""

   def setUp(self):
+    super(ClusterMHAIntegrationTest, self).setUp()
     self.x_train = np.random.uniform(size=(500, 32, 32))
     self.y_train = np.random.randint(low=0, high=1024, size=(500,))

     self.nr_of_clusters = 16
     self.params_clustering = {
-      "number_of_clusters": self.nr_of_clusters,
-      "cluster_centroids_init": CentroidInitialization.KMEANS_PLUS_PLUS,
+        "number_of_clusters": self.nr_of_clusters,
+        "cluster_centroids_init": CentroidInitialization.KMEANS_PLUS_PLUS,
     }

   def _get_model(self):
     """Returns functional model with MHA layer."""
-    inp = tf.keras.layers.Input(shape=(32,32), batch_size=100)
-    x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)(query=inp, value=inp)
+    inp = tf.keras.layers.Input(shape=(32, 32), batch_size=100)
+    x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)(
+        query=inp, value=inp)
     out = tf.keras.layers.Flatten()(x)
     model = tf.keras.Model(inputs=inp, outputs=out)
     return model

@@ -566,18 +569,70 @@ def testMHA(self):
     clustered_model = cluster.cluster_weights(model, **self.params_clustering)

     clustered_model.compile(
-      optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
-      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-      metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')])
-    clustered_model.fit(self.x_train, self.y_train, epochs=1, batch_size=100, verbose=1)
+        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
+        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")])
+    clustered_model.fit(
+        self.x_train, self.y_train, epochs=1, batch_size=100, verbose=1)

     stripped_model = cluster.strip_clustering(clustered_model)

-    layerMHA = stripped_model.layers[1]
-    for weight in layerMHA.weights:
-      if 'kernel' in weight.name:
+    layer_mha = stripped_model.layers[1]
+    for weight in layer_mha.weights:
+      if "kernel" in weight.name:
         nr_unique_weights = len(np.unique(weight.numpy()))
         assert nr_unique_weights == self.nr_of_clusters

+
+class ClusterPerChannelIntegrationTest(tf.test.TestCase,
+                                       parameterized.TestCase):
+  """Integration tests for per-channel clustering of Conv2D layer."""
+
+  def setUp(self):
+    super(ClusterPerChannelIntegrationTest, self).setUp()
+    self.x_train = np.random.uniform(size=(500, 32, 32))
+    self.y_train = np.random.randint(low=0, high=1024, size=(500,))
+
+    self.nr_of_clusters = 4
+    self.num_channels = 12
+    self.params_clustering = {
+        "number_of_clusters": self.nr_of_clusters,
+        "cluster_centroids_init": CentroidInitialization.KMEANS_PLUS_PLUS,
+        "cluster_per_channel": True
+    }
+
+  def _get_model(self):
+    """Returns functional model with Conv2D layer."""
+    inp = tf.keras.layers.Input(shape=(32, 32), batch_size=100)
+    x = tf.keras.layers.Reshape((32, 32, 1))(inp)
+    x = tf.keras.layers.Conv2D(
+        filters=self.num_channels, kernel_size=(3, 3),
+        activation="relu")(x)
+    x = tf.keras.layers.MaxPool2D(2, 2)(x)
+    out = tf.keras.layers.Flatten()(x)
+    model = tf.keras.Model(inputs=inp, outputs=out)
+    return model
+
+  @keras_parameterized.run_all_keras_modes
+  def testPerChannel(self):
+    model = self._get_model()
+
+    clustered_model = cluster.cluster_weights(model, **self.params_clustering)
+
+    clustered_model.compile(
+        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
+        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")])
+    clustered_model.fit(
+        self.x_train, self.y_train, epochs=1, batch_size=100, verbose=1)
+
+    stripped_model = cluster.strip_clustering(clustered_model)
+
+    layer_conv2d = stripped_model.layers[2]
+    for weight in layer_conv2d.weights:
+      if "kernel" in weight.name:
+        nr_unique_weights = len(np.unique(weight.numpy()))
+        assert nr_unique_weights == self.nr_of_clusters*self.num_channels
+
 if __name__ == "__main__":
   test.main()
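
For intuition about the assertion at the end of testPerChannel (the stripped Conv2D kernel holds at most number_of_clusters * num_channels unique values), here is a NumPy-only sketch of the arithmetic. It is not from this commit: evenly spaced centroids stand in for the library's configured CentroidInitialization and training-time centroid updates.

# Illustrative sketch: clustering each of the C output channels of a Conv2D
# kernel to k centroids independently leaves at most k unique values per
# channel, i.e. at most k * C unique values in the whole kernel.
import numpy as np

rng = np.random.default_rng(0)
k, C = 4, 12                              # clusters per channel, channels
kernel = rng.normal(size=(3, 3, 1, C))    # HWIO Conv2D kernel layout

clustered = np.empty_like(kernel)
for c in range(C):
    channel = kernel[..., c].ravel()
    # Crude stand-in for per-channel centroid fitting: k evenly spaced
    # centroids between the channel's min and max.
    centroids = np.linspace(channel.min(), channel.max(), k)
    nearest = np.abs(channel[:, None] - centroids[None, :]).argmin(axis=1)
    clustered[..., c] = centroids[nearest].reshape(kernel[..., c].shape)

print(len(np.unique(clustered)))          # <= k * C == 48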
