Skip to content

Commit 4e660f9

Browse files
committed
Support of MHA for clustering.
Change-Id: I7005bffa70770ce6e36859384dfca351e39318aa
1 parent d36c5a8 commit 4e660f9

File tree

5 files changed

+107
-0
lines changed

5 files changed

+107
-0
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,14 @@ def _add_clustering_wrapper(layer):
244244
preserve_sparsity,
245245
**kwargs,
246246
)
247+
if isinstance(layer, tf.keras.layers.MultiHeadAttention):
248+
return cluster_wrapper.ClusterWeightsMHA(
249+
layer,
250+
number_of_clusters,
251+
cluster_centroids_init,
252+
preserve_sparsity,
253+
**kwargs,
254+
)
247255

248256
return cluster_wrapper.ClusterWeights(layer, number_of_clusters,
249257
cluster_centroids_init,
@@ -301,6 +309,10 @@ def _strip_clustering_wrapper(layer):
301309
return tf.keras.models.clone_model(
302310
layer, input_tensors=None, clone_function=_strip_clustering_wrapper)
303311

312+
elif isinstance(layer, cluster_wrapper.ClusterWeightsMHA):
313+
# In case of MHA layer, use the overloaded implementation
314+
return layer.strip_clustering()
315+
304316
elif isinstance(layer, cluster_wrapper.ClusterWeights):
305317
# Update cluster associations in order to get the latest weights
306318
layer.update_clustered_weights_associations()

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,5 +534,47 @@ def testClusterStackedRNNCells(self):
534534
expected_unique_weights=self.params_clustering["number_of_clusters"],
535535
)
536536

537+
class ClusterMHAIntegrationTest(tf.test.TestCase, parameterized.TestCase):
538+
"""Integration tests for clustering MHA layer."""
539+
540+
def setUp(self):
541+
self.x_train = np.random.uniform(size=(500, 32, 32))
542+
self.y_train = np.random.randint(low=0, high=1024, size=(500,))
543+
544+
self.nr_of_clusters = 16
545+
self.params_clustering = {
546+
"number_of_clusters": self.nr_of_clusters,
547+
"cluster_centroids_init": CentroidInitialization.KMEANS_PLUS_PLUS,
548+
}
549+
550+
def _get_model(self):
551+
"""Returns functional model with MHA layer."""
552+
inp = tf.keras.layers.Input(shape=(32,32), batch_size=100)
553+
x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)(query=inp, value=inp)
554+
out = tf.keras.layers.Flatten()(x)
555+
model = tf.keras.Model(inputs=inp, outputs=out)
556+
return model
557+
558+
@keras_parameterized.run_all_keras_modes
559+
def testMHA(self):
560+
model = self._get_model()
561+
562+
clustered_model = cluster.cluster_weights(model, **self.params_clustering)
563+
564+
clustered_model.compile(
565+
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
566+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
567+
metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')])
568+
clustered_model.run_eagerly = True
569+
clustered_model.fit(self.x_train, self.y_train, epochs=1, batch_size=100, verbose=1)
570+
571+
stripped_model = cluster.strip_clustering(clustered_model)
572+
573+
layerMHA = stripped_model.layers[1]
574+
for weight in layerMHA.weights:
575+
if 'kernel' in weight.name:
576+
nr_unique_weights = len(np.unique(weight.numpy()))
577+
assert nr_unique_weights == self.nr_of_clusters
578+
537579
if __name__ == "__main__":
538580
test.main()

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# ==============================================================================
1515
"""Keras ClusterWeights wrapper API."""
1616

17+
from operator import attrgetter
18+
1719
import tensorflow as tf
1820

1921
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
@@ -388,3 +390,32 @@ def set_weight_to_layer(self, weight_name, new_weight):
388390
return setattr(return_layer_cell, weight_name_no_index, new_weight)
389391
else:
390392
raise ValueError('No cells in the RNN layer to set weights for.')
393+
394+
395+
class ClusterWeightsMHA(ClusterWeights):
396+
"""This wrapper augments a keras MHA layer so that the weights can be clustered."""
397+
398+
def get_weight_from_layer(self, weight_name):
399+
pre, _, post = weight_name.rpartition('.')
400+
return getattr(getattr(self.layer, pre), post)
401+
402+
def set_weight_to_layer(self, weight_name, new_weight):
403+
pre, _, post = weight_name.rpartition('.')
404+
layer = attrgetter(pre)(self.layer)
405+
setattr(layer, post, new_weight)
406+
407+
def strip_clustering(self):
408+
""" The restore from config is not working for MHA layer, because
409+
weights are not created when the build function is called. Therefore,
410+
original weights have been replaced in the layer."""
411+
for weight_name, original_weight in self.original_clusterable_weights.items():
412+
413+
# Get the clustered weights
414+
clustered_weight = self.get_weight_from_layer(weight_name)
415+
416+
# Re-assign these weights to the original
417+
original_weight.assign(clustered_weight)
418+
setattr(self.layer, weight_name, original_weight)
419+
420+
return self.layer
421+

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ class ClusteringRegistry(object):
9292
layers.Bidirectional,
9393
])
9494

95+
_SUPPORTED_MHA_LAYERS = {
96+
tf.keras.layers.MultiHeadAttention,
97+
}
98+
9599
@classmethod
96100
def supports(cls, layer):
97101
"""Returns whether the registry supports this layer type.
@@ -121,6 +125,9 @@ def supports(cls, layer):
121125
return False
122126
return True
123127

128+
if layer.__class__ in cls._SUPPORTED_MHA_LAYERS:
129+
return True
130+
124131
return False
125132

126133
def _get_rnn_cells(rnn_layer): # pylint: disable=no-self-argument
@@ -190,8 +197,16 @@ def get_clusterable_weights_rnn_cell(cell, i):
190197
clusterable_weights = get_clusterable_weights_rnn_cell(rnn_cell, 0)
191198
return clusterable_weights
192199

200+
def get_clusterable_weights_mha(): # pylint: disable=missing-docstring
201+
return [('_query_dense.kernel', layer._query_dense.kernel),
202+
('_key_dense.kernel', layer._key_dense.kernel),
203+
('_value_dense.kernel', layer._value_dense.kernel),
204+
('_output_dense.kernel', layer._output_dense.kernel)]
205+
193206
if layer.__class__ in cls._SUPPORTED_RNN_LAYERS:
194207
layer.get_clusterable_weights = get_clusterable_weights_rnn
208+
elif layer.__class__ in cls._SUPPORTED_MHA_LAYERS:
209+
layer.get_clusterable_weights = get_clusterable_weights_mha
195210
else:
196211
layer.get_clusterable_weights = get_clusterable_weights
197212

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,13 @@ def testMakeClusterableRaisesErrorOnRNNLayersUnsupportedCell(self):
512512
with self.assertRaises(ValueError):
513513
ClusterRegistry.make_clusterable(layer)
514514

515+
def testSupportsMultiHeadAttentionLayer(self):
516+
"""
517+
Verifies that ClusterRegistry supports a MultiHeadAttention layer.
518+
"""
519+
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
520+
self.assertTrue(ClusterRegistry.supports(layer))
521+
ClusterRegistry.make_clusterable(layer)
515522

516523
if __name__ == '__main__':
517524
test.main()

0 commit comments

Comments
 (0)