Skip to content

Commit d53ff08

Browse files
Merge pull request #845 from wwwind:mha_clustering_support
PiperOrigin-RevId: 402726607
2 parents f3cbc55 + 1e247de commit d53ff08

File tree

6 files changed

+215
-0
lines changed

6 files changed

+215
-0
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,14 @@ def _add_clustering_wrapper(layer):
244244
preserve_sparsity,
245245
**kwargs,
246246
)
247+
if isinstance(layer, tf.keras.layers.MultiHeadAttention):
248+
return cluster_wrapper.ClusterWeightsMHA(
249+
layer,
250+
number_of_clusters,
251+
cluster_centroids_init,
252+
preserve_sparsity,
253+
**kwargs,
254+
)
247255

248256
return cluster_wrapper.ClusterWeights(layer, number_of_clusters,
249257
cluster_centroids_init,
@@ -301,6 +309,13 @@ def _strip_clustering_wrapper(layer):
301309
return tf.keras.models.clone_model(
302310
layer, input_tensors=None, clone_function=_strip_clustering_wrapper)
303311

312+
elif isinstance(layer, cluster_wrapper.ClusterWeightsMHA):
313+
# Update cluster associations in order to get the latest weights
314+
layer.update_clustered_weights_associations()
315+
316+
# In case of MHA layer, use the overloaded implementation
317+
return layer.strip_clustering()
318+
304319
elif isinstance(layer, cluster_wrapper.ClusterWeights):
305320
# Update cluster associations in order to get the latest weights
306321
layer.update_clustered_weights_associations()

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,5 +534,46 @@ def testClusterStackedRNNCells(self):
534534
expected_unique_weights=self.params_clustering["number_of_clusters"],
535535
)
536536

537+
class ClusterMHAIntegrationTest(tf.test.TestCase, parameterized.TestCase):
538+
"""Integration tests for clustering MHA layer."""
539+
540+
def setUp(self):
541+
self.x_train = np.random.uniform(size=(500, 32, 32))
542+
self.y_train = np.random.randint(low=0, high=1024, size=(500,))
543+
544+
self.nr_of_clusters = 16
545+
self.params_clustering = {
546+
"number_of_clusters": self.nr_of_clusters,
547+
"cluster_centroids_init": CentroidInitialization.KMEANS_PLUS_PLUS,
548+
}
549+
550+
def _get_model(self):
551+
"""Returns functional model with MHA layer."""
552+
inp = tf.keras.layers.Input(shape=(32,32), batch_size=100)
553+
x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)(query=inp, value=inp)
554+
out = tf.keras.layers.Flatten()(x)
555+
model = tf.keras.Model(inputs=inp, outputs=out)
556+
return model
557+
558+
@keras_parameterized.run_all_keras_modes
559+
def testMHA(self):
560+
model = self._get_model()
561+
562+
clustered_model = cluster.cluster_weights(model, **self.params_clustering)
563+
564+
clustered_model.compile(
565+
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
566+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
567+
metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')])
568+
clustered_model.fit(self.x_train, self.y_train, epochs=1, batch_size=100, verbose=1)
569+
570+
stripped_model = cluster.strip_clustering(clustered_model)
571+
572+
layerMHA = stripped_model.layers[1]
573+
for weight in layerMHA.weights:
574+
if 'kernel' in weight.name:
575+
nr_unique_weights = len(np.unique(weight.numpy()))
576+
assert nr_unique_weights == self.nr_of_clusters
577+
537578
if __name__ == "__main__":
538579
test.main()

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# ==============================================================================
1515
"""Keras ClusterWeights wrapper API."""
1616

17+
from operator import attrgetter
18+
1719
import tensorflow as tf
1820

1921
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
@@ -388,3 +390,32 @@ def set_weight_to_layer(self, weight_name, new_weight):
388390
return setattr(return_layer_cell, weight_name_no_index, new_weight)
389391
else:
390392
raise ValueError('No cells in the RNN layer to set weights for.')
393+
394+
395+
class ClusterWeightsMHA(ClusterWeights):
396+
"""This wrapper augments a keras MHA layer so that the weights can be clustered."""
397+
398+
def get_weight_from_layer(self, weight_name):
399+
pre, _, post = weight_name.rpartition('.')
400+
return getattr(getattr(self.layer, pre), post)
401+
402+
def set_weight_to_layer(self, weight_name, new_weight):
403+
pre, _, post = weight_name.rpartition('.')
404+
layer = attrgetter(pre)(self.layer)
405+
setattr(layer, post, new_weight)
406+
407+
def strip_clustering(self):
408+
""" The restore from config is not working for MHA layer, because
409+
weights are not created when the build function is called. Therefore,
410+
original weights have been replaced in the layer."""
411+
for weight_name, original_weight in self.original_clusterable_weights.items():
412+
413+
# Get the clustered weights
414+
clustered_weight = self.get_weight_from_layer(weight_name)
415+
416+
# Re-assign these weights to the original
417+
original_weight.assign(clustered_weight)
418+
setattr(self.layer, weight_name, original_weight)
419+
420+
return self.layer
421+

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ class ClusteringRegistry(object):
9292
layers.Bidirectional,
9393
])
9494

95+
_SUPPORTED_MHA_LAYERS = {
96+
tf.keras.layers.MultiHeadAttention,
97+
}
98+
9599
@classmethod
96100
def supports(cls, layer):
97101
"""Returns whether the registry supports this layer type.
@@ -121,6 +125,9 @@ def supports(cls, layer):
121125
return False
122126
return True
123127

128+
if layer.__class__ in cls._SUPPORTED_MHA_LAYERS:
129+
return True
130+
124131
return False
125132

126133
def _get_rnn_cells(rnn_layer): # pylint: disable=no-self-argument
@@ -190,8 +197,16 @@ def get_clusterable_weights_rnn_cell(cell, i):
190197
clusterable_weights = get_clusterable_weights_rnn_cell(rnn_cell, 0)
191198
return clusterable_weights
192199

200+
def get_clusterable_weights_mha(): # pylint: disable=missing-docstring
201+
return [('_query_dense.kernel', layer._query_dense.kernel),
202+
('_key_dense.kernel', layer._key_dense.kernel),
203+
('_value_dense.kernel', layer._value_dense.kernel),
204+
('_output_dense.kernel', layer._output_dense.kernel)]
205+
193206
if layer.__class__ in cls._SUPPORTED_RNN_LAYERS:
194207
layer.get_clusterable_weights = get_clusterable_weights_rnn
208+
elif layer.__class__ in cls._SUPPORTED_MHA_LAYERS:
209+
layer.get_clusterable_weights = get_clusterable_weights_mha
195210
else:
196211
layer.get_clusterable_weights = get_clusterable_weights
197212

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,13 @@ def testMakeClusterableRaisesErrorOnRNNLayersUnsupportedCell(self):
512512
with self.assertRaises(ValueError):
513513
ClusterRegistry.make_clusterable(layer)
514514

515+
def testSupportsMultiHeadAttentionLayer(self):
516+
"""
517+
Verifies that ClusterRegistry supports a MultiHeadAttention layer.
518+
"""
519+
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
520+
self.assertTrue(ClusterRegistry.supports(layer))
521+
ClusterRegistry.make_clusterable(layer)
515522

516523
if __name__ == '__main__':
517524
test.main()
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
# pylint: disable=missing-docstring
16+
"""Train a simple convnet with MultiHeadAttention layer on MNIST dataset
17+
and cluster it.
18+
"""
19+
import tensorflow as tf
20+
import tensorflow_model_optimization as tfmot
21+
22+
import numpy as np
23+
24+
NUMBER_OF_CLUSTERS = 3
25+
26+
# Load MNIST dataset
27+
mnist = tf.keras.datasets.mnist
28+
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
29+
30+
# Normalize the input image so that each pixel value is between 0 to 1.
31+
train_images = train_images / 255.0
32+
test_images = test_images / 255.0
33+
34+
# define model
35+
input = tf.keras.layers.Input(shape=(28, 28))
36+
x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16, name="mha")(
37+
query=input, value=input
38+
)
39+
x = tf.keras.layers.Flatten()(x)
40+
out = tf.keras.layers.Dense(10)(x)
41+
model = tf.keras.Model(inputs=input, outputs=out)
42+
43+
# Train the digit classification model
44+
model.compile(
45+
optimizer="adam",
46+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
47+
metrics=["accuracy"],
48+
)
49+
50+
model.fit(
51+
train_images, train_labels, epochs=1, validation_split=0.1,
52+
)
53+
54+
score = model.evaluate(test_images, test_labels, verbose=0)
55+
print('Model test loss:', score[0])
56+
print('Model test accuracy:', score[1])
57+
58+
# Compute end step to finish pruning after 2 epochs.
59+
batch_size = 128
60+
epochs = 1
61+
validation_split = 0.1 # 10% of training set will be used for validation set.
62+
63+
# Define model for clustering
64+
cluster_weights = tfmot.clustering.keras.cluster_weights
65+
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
66+
67+
clustering_params = {
68+
"number_of_clusters": NUMBER_OF_CLUSTERS,
69+
"cluster_centroids_init": CentroidInitialization.KMEANS_PLUS_PLUS,
70+
}
71+
model_for_clustering = cluster_weights(model, **clustering_params)
72+
73+
# `cluster_weights` requires a recompile.
74+
model_for_clustering.compile(
75+
optimizer="adam",
76+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
77+
metrics=["accuracy"],
78+
)
79+
80+
model_for_clustering.fit(
81+
train_images,
82+
train_labels,
83+
batch_size=batch_size,
84+
epochs=epochs,
85+
validation_split=validation_split,
86+
)
87+
88+
score = model_for_clustering.evaluate(test_images, test_labels, verbose=0)
89+
print('Clustered model test loss:', score[0])
90+
print('Clustered model test accuracy:', score[1])
91+
92+
# Strip clustering from the model
93+
clustered_model = tfmot.clustering.keras.strip_clustering(model_for_clustering)
94+
clustered_model.compile(
95+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
96+
optimizer='adam',
97+
metrics=['accuracy'])
98+
99+
score = clustered_model.evaluate(test_images, test_labels, verbose=0)
100+
print('Stripped clustered model test loss:', score[0])
101+
print('Stripped clustered model test accuracy:', score[1])
102+
103+
# Check that numbers of weights for MHA layer is the given number of clusters.
104+
mha_weights = list(filter(lambda x: 'mha' in x.name and 'kernel' in x.name, clustered_model.weights))
105+
for x in mha_weights:
106+
assert len(np.unique(x.numpy())) == NUMBER_OF_CLUSTERS

0 commit comments

Comments
 (0)