Skip to content

Commit 97aca5b

Browse files
Per-channel Clustering Implementation and Tests
1 parent 1bec520 commit 97aca5b

File tree

10 files changed

+508
-112
lines changed

10 files changed

+508
-112
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,11 @@ def cluster_weights(
128128
to_cluster,
129129
number_of_clusters,
130130
cluster_centroids_init,
131-
preserve_sparsity=False,
132131
**kwargs)
133132

134133

135134
def _cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init,
136-
preserve_sparsity, **kwargs):
135+
preserve_sparsity=False, cluster_per_channel=False, **kwargs):
137136
"""Modifies a keras layer or model to be clustered during training.
138137
139138
This function wraps a keras model or layer with clustering functionality
@@ -158,6 +157,7 @@ def _cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init,
158157
clustering_params = {
159158
'number_of_clusters': 8,
160159
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
160+
'cluster_per_channel': False,
161161
'preserve_sparsity': False
162162
}
163163
@@ -170,6 +170,7 @@ def _cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init,
170170
clustering_params = {
171171
'number_of_clusters': 8,
172172
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
173+
'cluster_per_channel': False,
173174
'preserve_sparsity': False
174175
}
175176
@@ -202,8 +203,19 @@ def _cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init,
202203
8 unique values will be used in each weight array.
203204
cluster_centroids_init: `tfmot.clustering.keras.CentroidInitialization`
204205
instance that determines how the cluster centroids will be initialized.
206+
cluster_per_channel: optional boolean value that determines whether the
207+
clustering should be applied separately on the individual channels, as
208+
opposed to the whole kernel. Only applicable to Conv2D layers and is
209+
ignored otherwise. The number of clusters in this case would be
210+
num_clusters*num_channels. This is useful for the collaborative
211+
optimization pipeline where clustering is followed by quantization,
212+
since Conv2D is quantized per-channel, so we end up with
213+
num_clusters*num_channels total clusters at the end. Clustering
214+
per-channel from the beginning leads to better accuracy.
205215
preserve_sparsity (experimental): optional boolean value that determines
206216
whether or not sparsity preservation will be enforced during training.
217+
When used along with cluster_per_channel flag above, the zero centroid
218+
is treated separately and maintained individually for each channel.
207219
**kwargs: Additional keyword arguments to be passed to the keras layer.
208220
Ignored when to_cluster is not a keras layer.
209221
@@ -255,7 +267,8 @@ def _add_clustering_wrapper(layer):
255267

256268
return cluster_wrapper.ClusterWeights(layer, number_of_clusters,
257269
cluster_centroids_init,
258-
preserve_sparsity, **kwargs)
270+
preserve_sparsity,
271+
cluster_per_channel, **kwargs)
259272

260273
def _wrap_list(layers):
261274
output = []

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,5 +579,52 @@ def testMHA(self):
579579
nr_unique_weights = len(np.unique(weight.numpy()))
580580
assert nr_unique_weights == self.nr_of_clusters
581581

582+
class ClusterPerChannelIntegrationTest(tf.test.TestCase, parameterized.TestCase):
  """Integration tests for per-channel clustering of a Conv2D layer.

  A small functional model with a single Conv2D layer is clustered with
  `cluster_per_channel=True`, trained for one epoch, stripped of its
  clustering wrappers, and the Conv2D kernel is then checked to hold
  `nr_of_clusters * num_channels` distinct values.
  """

  def setUp(self):
    # Let the base test case perform its own fixture initialisation before
    # any test data is generated.
    super().setUp()
    self.x_train = np.random.uniform(size=(500, 32, 32))
    self.y_train = np.random.randint(low=0, high=1024, size=(500,))

    self.nr_of_clusters = 4
    self.num_channels = 12
    self.params_clustering = {
        "number_of_clusters": self.nr_of_clusters,
        "cluster_centroids_init": CentroidInitialization.KMEANS_PLUS_PLUS,
        "cluster_per_channel": True
    }

  def _get_model(self):
    """Returns functional model with Conv2D layer."""
    inp = tf.keras.layers.Input(shape=(32, 32), batch_size=100)
    x = tf.keras.layers.Reshape((32, 32, 1))(inp)
    x = tf.keras.layers.Conv2D(
        filters=self.num_channels, kernel_size=(3, 3),
        activation='relu')(x)
    x = tf.keras.layers.MaxPool2D(2, 2)(x)
    out = tf.keras.layers.Flatten()(x)
    model = tf.keras.Model(inputs=inp, outputs=out)
    return model

  @keras_parameterized.run_all_keras_modes
  def testPerChannel(self):
    """Checks the stripped Conv2D kernel has clusters * channels unique values."""
    model = self._get_model()

    clustered_model = cluster.cluster_weights(model, **self.params_clustering)

    clustered_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')])
    clustered_model.fit(
        self.x_train, self.y_train, epochs=1, batch_size=100, verbose=1)

    stripped_model = cluster.strip_clustering(clustered_model)

    # Layer order in _get_model: Input (0), Reshape (1), Conv2D (2).
    conv2d_layer = stripped_model.layers[2]
    kernels = [w for w in conv2d_layer.weights if 'kernel' in w.name]
    # Guard against a vacuous pass: without a kernel nothing would be checked.
    self.assertTrue(kernels, 'No kernel weight found in the Conv2D layer.')
    for weight in kernels:
      nr_unique_weights = len(np.unique(weight.numpy()))
      # assertEqual (not a bare assert) is not stripped under `python -O` and
      # reports both values on failure.
      self.assertEqual(nr_unique_weights,
                       self.nr_of_clusters * self.num_channels)
628+
582629
if __name__ == "__main__":
583630
test.main()

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 76 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(self,
5858
number_of_clusters,
5959
cluster_centroids_init=CentroidInitialization.KMEANS_PLUS_PLUS,
6060
preserve_sparsity=False,
61+
cluster_per_channel=False,
6162
cluster_gradient_aggregation=GradientAggregation.SUM,
6263
**kwargs):
6364
if not isinstance(layer, Layer):
@@ -101,6 +102,17 @@ def __init__(self,
101102
# The number of cluster centroids
102103
self.number_of_clusters = number_of_clusters
103104

105+
# Whether to cluster Conv2D kernels per-channel.
106+
# In case the layer isn't a Conv2D, this isn't
107+
# applicable
108+
self.cluster_per_channel = (
109+
cluster_per_channel if isinstance(layer, tf.keras.layers.Conv2D)
110+
else False)
111+
112+
# Number of channels in a Conv2D layer, to be
113+
# used in the case of per-channel clustering
114+
self.num_channels = None
115+
104116
# Whether to apply sparsity preservation or not
105117
self.preserve_sparsity = preserve_sparsity
106118

@@ -137,12 +149,31 @@ def __init__(self,
137149
hasattr(layer, '_batch_input_shape')):
138150
self._batch_input_shape = self.layer._batch_input_shape
139151

152+
# In the case of Conv2D layer, the data_format
153+
# needs to be preserved to be used for per-channel
154+
# clustering
155+
if hasattr(layer, 'data_format'):
156+
self.data_format = self.layer.data_format
157+
else:
158+
self.data_format = None
159+
140160
# Save the input shape specified in the build
141161
self.build_input_shape = None
142162

143163
def _make_layer_name(self, layer):
144164
return '{}_{}'.format('cluster', layer.name)
145165

166+
def _get_zero_idx_mask(self, centroids, zero_cluster):
  """Builds a multiplicative mask that zeroes out the designated zero centroid.

  Args:
    centroids: Tensor or variable holding the cluster centroids.
    zero_cluster: Value of the centroid that must be forced to zero; it is
      compared element-wise against `centroids`.

  Returns:
    A tensor with the dtype of `centroids` that is 0 wherever a centroid
    equals `zero_cluster` and 1 everywhere else.
  """
  differs_from_zero_cluster = tf.math.not_equal(centroids, zero_cluster)
  # Use the centroids' own dtype rather than a hard-coded float32, so the mask
  # can be multiplied with centroids of any floating-point dtype. The result
  # is unchanged for float32 centroids.
  return tf.cast(differs_from_zero_cluster, dtype=centroids.dtype)
171+
172+
def _get_zero_centroid(self, centroids, zero_idx_mask):
  """Applies `zero_idx_mask` to `centroids` by element-wise multiplication.

  Args:
    centroids: Tensor or variable holding the cluster centroids.
    zero_idx_mask: Mask to multiply the centroids with.

  Returns:
    The element-wise product of `centroids` and `zero_idx_mask`.
  """
  return tf.math.multiply(centroids, zero_idx_mask)
176+
146177
def get_weight_from_layer(self, weight_name):
  """Looks up the attribute called `weight_name` on the wrapped layer.

  Args:
    weight_name: Name of the attribute to read from `self.layer`.

  Returns:
    Whatever `self.layer` stores under `weight_name`.
  """
  wrapped_layer = self.layer
  return getattr(wrapped_layer, weight_name)
148179

@@ -173,15 +204,28 @@ def build(self, input_shape):
173204
i for i, w in enumerate(self.layer.weights) if w is original_weight)
174205
self.position_original_weights[position_original_weight] = weight_name
175206

207+
# In the case of per-channel clustering, the number of channels,
208+
# per-channel number of clusters, as well as the overall number
209+
# of clusters all need to be preserved in the wrapper.
210+
if self.cluster_per_channel:
211+
self.num_channels = (
212+
original_weight.shape[1] if self.data_format == "channels_first"
213+
else original_weight.shape[-1])
214+
215+
centroid_init_factory = clustering_centroids.CentroidsInitializerFactory
216+
centroid_init = centroid_init_factory.get_centroid_initializer(
217+
self.cluster_centroids_init)(
218+
weight, self.number_of_clusters,
219+
self.cluster_per_channel,
220+
self.num_channels,
221+
self.preserve_sparsity)
222+
176223
# Init the cluster centroids
177-
cluster_centroids = (
178-
clustering_centroids.CentroidsInitializerFactory
179-
.get_centroid_initializer(self.cluster_centroids_init)(
180-
weight, self.number_of_clusters,
181-
self.preserve_sparsity).get_cluster_centroids())
224+
cluster_centroids = (centroid_init.get_cluster_centroids())
225+
182226
self.cluster_centroids[weight_name] = self.add_weight(
183227
'{}{}'.format('cluster_centroids_', weight_name),
184-
shape=(self.number_of_clusters,),
228+
shape=(cluster_centroids.shape),
185229
dtype=weight.dtype,
186230
trainable=True,
187231
initializer=tf.keras.initializers.Constant(value=cluster_centroids))
@@ -198,10 +242,11 @@ def build(self, input_shape):
198242
weight_name_no_index = weight_name
199243
self.clustering_algorithms[weight_name] = (
200244
clustering_registry.ClusteringLookupRegistry().get_clustering_impl(
201-
self.layer, weight_name_no_index)
245+
self.layer, weight_name_no_index, self.cluster_per_channel)
202246
(
203247
clusters_centroids=self.cluster_centroids[weight_name],
204248
cluster_gradient_aggregation=self.cluster_gradient_aggregation,
249+
data_format=self.data_format,
205250
))
206251

207252
# Init the pulling_indices (weights associations)
@@ -233,18 +278,27 @@ def update_clustered_weights_associations(self):
233278
):
234279

235280
if self.preserve_sparsity:
236-
# Set the smallest centroid to zero to force sparsity
237-
# and avoid extra cluster from forming
238-
zero_idx_mask = (
239-
tf.cast(
240-
tf.math.not_equal(
241-
self.cluster_centroids[weight_name],
242-
self.cluster_centroids[weight_name][
243-
self.zero_idx[weight_name]]),
244-
dtype=tf.float32))
245-
self.cluster_centroids[weight_name].assign(
246-
tf.math.multiply(self.cluster_centroids[weight_name],
247-
zero_idx_mask))
281+
# In the case of per-channel clustering, sparsity
282+
# needs to be preserved per-channel
283+
if self.cluster_per_channel:
284+
for channel in range(self.num_channels):
285+
zero_idx_mask = (
286+
self._get_zero_idx_mask(self.cluster_centroids[weight_name][channel],
287+
self.cluster_centroids[weight_name][channel][
288+
self.zero_idx[weight_name][channel]]))
289+
self.cluster_centroids[weight_name][channel].assign(
290+
self._get_zero_centroid(self.cluster_centroids[weight_name][channel],
291+
zero_idx_mask))
292+
else:
293+
# Set the smallest centroid to zero to force sparsity
294+
# and avoid extra cluster from forming
295+
zero_idx_mask = self._get_zero_idx_mask(self.cluster_centroids[weight_name],
296+
self.cluster_centroids[weight_name][
297+
self.zero_idx[weight_name]])
298+
self.cluster_centroids[weight_name].assign(
299+
self._get_zero_centroid(self.cluster_centroids[weight_name],
300+
zero_idx_mask))
301+
248302
# During training, the original zero weights can drift slightly.
249303
# We want to prevent this by forcing them to stay zero at the places
250304
# where they were originally zero to begin with.
@@ -284,6 +338,7 @@ def get_config(self):
284338
'cluster_centroids_init': self.cluster_centroids_init,
285339
'preserve_sparsity': self.preserve_sparsity,
286340
'cluster_gradient_aggregation': self.cluster_gradient_aggregation,
341+
'cluster_per_channel': self.cluster_per_channel,
287342
**base_config
288343
}
289344
return config
@@ -296,12 +351,14 @@ def from_config(cls, config, custom_objects=None):
296351
cluster_centroids_init = config.pop('cluster_centroids_init')
297352
preserve_sparsity = config.pop('preserve_sparsity')
298353
cluster_gradient_aggregation = config.pop('cluster_gradient_aggregation')
354+
cluster_per_channel = config.pop('cluster_per_channel')
299355

300356
config['number_of_clusters'] = number_of_clusters
301357
config['cluster_centroids_init'] = cluster_config.CentroidInitialization(
302358
cluster_centroids_init)
303359
config['preserve_sparsity'] = preserve_sparsity
304360
config['cluster_gradient_aggregation'] = cluster_gradient_aggregation
361+
config['cluster_per_channel'] = cluster_per_channel
305362

306363
layer = tf.keras.layers.deserialize(
307364
config.pop('layer'), custom_objects=custom_objects)

0 commit comments

Comments
 (0)