Skip to content

Commit 1e247de

Browse files
committed
Added example for clustering of MHA.
Change-Id: I163333ed3e7d4c45383c2b90b56bfa27368f7999
1 parent 4e660f9 commit 1e247de

File tree

3 files changed

+109
-1
lines changed

3 files changed

+109
-1
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,9 @@ def _strip_clustering_wrapper(layer):
310310
layer, input_tensors=None, clone_function=_strip_clustering_wrapper)
311311

312312
elif isinstance(layer, cluster_wrapper.ClusterWeightsMHA):
313+
# Update cluster associations in order to get the latest weights
314+
layer.update_clustered_weights_associations()
315+
313316
# In case of MHA layer, use the overloaded implementation
314317
return layer.strip_clustering()
315318

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,6 @@ def testMHA(self):
565565
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
566566
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
567567
metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')])
568-
clustered_model.run_eagerly = True
569568
clustered_model.fit(self.x_train, self.y_train, epochs=1, batch_size=100, verbose=1)
570569

571570
stripped_model = cluster.strip_clustering(clustered_model)
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
# pylint: disable=missing-docstring
16+
"""Train a simple model with a MultiHeadAttention layer on the MNIST dataset
17+
and cluster it.
18+
"""
19+
import tensorflow as tf
20+
import tensorflow_model_optimization as tfmot
21+
22+
import numpy as np
23+
24+
NUMBER_OF_CLUSTERS = 3
25+
26+
# Load MNIST dataset
27+
mnist = tf.keras.datasets.mnist
28+
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
29+
30+
# Normalize the input images so that each pixel value is between 0 and 1.
31+
train_images = train_images / 255.0
32+
test_images = test_images / 255.0
33+
34+
# Define the model: a MultiHeadAttention layer followed by a Dense classifier.
35+
input = tf.keras.layers.Input(shape=(28, 28))
36+
x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16, name="mha")(
37+
query=input, value=input
38+
)
39+
x = tf.keras.layers.Flatten()(x)
40+
out = tf.keras.layers.Dense(10)(x)
41+
model = tf.keras.Model(inputs=input, outputs=out)
42+
43+
# Train the digit classification model
44+
model.compile(
45+
optimizer="adam",
46+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
47+
metrics=["accuracy"],
48+
)
49+
50+
model.fit(
51+
train_images, train_labels, epochs=1, validation_split=0.1,
52+
)
53+
54+
score = model.evaluate(test_images, test_labels, verbose=0)
55+
print('Model test loss:', score[0])
56+
print('Model test accuracy:', score[1])
57+
58+
# Training parameters used when fitting the model with clustering applied.
59+
batch_size = 128
60+
epochs = 1
61+
validation_split = 0.1 # 10% of training set will be used for validation set.
62+
63+
# Define model for clustering
64+
cluster_weights = tfmot.clustering.keras.cluster_weights
65+
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
66+
67+
clustering_params = {
68+
"number_of_clusters": NUMBER_OF_CLUSTERS,
69+
"cluster_centroids_init": CentroidInitialization.KMEANS_PLUS_PLUS,
70+
}
71+
model_for_clustering = cluster_weights(model, **clustering_params)
72+
73+
# `cluster_weights` requires a recompile.
74+
model_for_clustering.compile(
75+
optimizer="adam",
76+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
77+
metrics=["accuracy"],
78+
)
79+
80+
model_for_clustering.fit(
81+
train_images,
82+
train_labels,
83+
batch_size=batch_size,
84+
epochs=epochs,
85+
validation_split=validation_split,
86+
)
87+
88+
score = model_for_clustering.evaluate(test_images, test_labels, verbose=0)
89+
print('Clustered model test loss:', score[0])
90+
print('Clustered model test accuracy:', score[1])
91+
92+
# Strip clustering from the model
93+
clustered_model = tfmot.clustering.keras.strip_clustering(model_for_clustering)
94+
clustered_model.compile(
95+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
96+
optimizer='adam',
97+
metrics=['accuracy'])
98+
99+
score = clustered_model.evaluate(test_images, test_labels, verbose=0)
100+
print('Stripped clustered model test loss:', score[0])
101+
print('Stripped clustered model test accuracy:', score[1])
102+
103+
# Check that each MHA kernel contains exactly NUMBER_OF_CLUSTERS unique values.
104+
mha_weights = list(filter(lambda x: 'mha' in x.name and 'kernel' in x.name, clustered_model.weights))
105+
for x in mha_weights:
106+
assert len(np.unique(x.numpy())) == NUMBER_OF_CLUSTERS

0 commit comments

Comments
 (0)