Skip to content

Commit b04c6c4

Browse files
Merge pull request #870 from wwwind:mha_pruning
PiperOrigin-RevId: 411754405
2 parents 5456312 + 516c667 commit b04c6c4

File tree

3 files changed

+145
-0
lines changed

3 files changed

+145
-0
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ def _get_params_for_layer(self, layer_type):
105105
# Embedding has a separate test since training it is not
106106
# feasible as a single layer.
107107
layers.Embedding: (None, None),
108+
109+
# MultiHeadAttention layer has a separate test
110+
layers.MultiHeadAttention: (None, None)
108111
}[layer_type]
109112

110113
def setUp(self):
@@ -528,6 +531,28 @@ def testPruneRecursivelyReachesTargetSparsity(self):
528531
input_data = np.random.randint(10, size=(32, 10))
529532
self._check_strip_pruning_matches_original(model, 0.5, input_data)
530533

534+
def testMHALayerReachesTargetSparsity(self):
  """Trains a pruned MultiHeadAttention model and checks its sparsity.

  A functional model is built in which one input tensor is fed to a
  MultiHeadAttention layer as both query and value, followed by Flatten.
  The model is wrapped with `prune_low_magnitude` using `self.params` and
  trained for a single epoch on random data.
  `test_utils.assert_model_sparsity` is called with 0.0 before training and
  with 0.5 after it, and `_check_strip_pruning_matches_original` is then run
  with the training inputs.
  """
  batch_size = 100
  inputs = tf.keras.layers.Input(shape=(32, 32), batch_size=batch_size)
  mha_layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)
  attention = mha_layer(query=inputs, value=inputs)
  flattened = tf.keras.layers.Flatten()(attention)
  model = prune.prune_low_magnitude(
      tf.keras.Model(inputs=inputs, outputs=flattened), **self.params)

  # Random features and integer labels. The label range [0, 1024) presumably
  # matches the flattened attention output size (32 * 32) -- TODO confirm.
  x_train = np.random.uniform(size=(500, 32, 32))
  y_train = np.random.randint(low=0, high=1024, size=(500,))

  model.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')])

  # Sparsity is asserted as 0.0 before any training step has run.
  test_utils.assert_model_sparsity(self, 0.0, model)

  # UpdatePruningStep is the only callback passed to the training run.
  pruning_step_callback = pruning_callbacks.UpdatePruningStep()
  model.fit(
      x_train,
      y_train,
      epochs=1,
      batch_size=batch_size,
      callbacks=[pruning_step_callback])

  test_utils.assert_model_sparsity(self, 0.5, model)
  self._check_strip_pruning_matches_original(model, 0.5, x_train)
555+
531556
@parameterized.parameters(test_utils.model_type_keys())
532557
def testPrunesMnist_ReachesTargetSparsity(self, model_type):
533558
model = test_utils.build_mnist_model(model_type, self.params)

tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ class PruneRegistry(object):
9393
layers.MaxPooling1D: [],
9494
layers.MaxPooling2D: [],
9595
layers.MaxPooling3D: [],
96+
layers.MultiHeadAttention: [
97+
'_query_dense.kernel',
98+
'_key_dense.kernel',
99+
'_value_dense.kernel',
100+
'_output_dense.kernel'],
96101
layers.experimental.preprocessing.Rescaling.__class__: [],
97102
TensorFlowOpLayer: [],
98103
}
@@ -163,6 +168,10 @@ def _get_rnn_cells(cls, rnn_layer):
163168
def _is_rnn_layer(cls, layer):
164169
return layer.__class__ in cls._RNN_LAYERS
165170

171+
@classmethod
def _is_mha_layer(cls, layer):
  """Returns True when `layer`'s class is exactly `layers.MultiHeadAttention`.

  The check is an identity comparison on `layer.__class__`, so subclasses of
  MultiHeadAttention are not matched.
  """
  layer_class = layer.__class__
  return layer_class is layers.MultiHeadAttention
174+
166175
@classmethod
def _weight_names(cls, layer):
  """Returns the `_LAYERS_WEIGHTS_MAP` entry stored for `layer`'s class.

  The lookup key is `layer.__class__` itself; no fallback is applied here, so
  a class missing from the map propagates the lookup error to the caller.
  """
  layer_class = layer.__class__
  return cls._LAYERS_WEIGHTS_MAP[layer_class]
@@ -202,8 +211,20 @@ def get_prunable_weights_rnn_cell(cell):
202211
prunable_weights.extend(get_prunable_weights_rnn_cell(rnn_cell))
203212
return prunable_weights
204213

214+
def get_prunable_weights_mha():
  """Returns the objects named by the registry entry for `layer`.

  `cls` and `layer` are taken from the enclosing scope. Every name returned
  by `cls._weight_names(layer)` is split at its last '.': the part before the
  dot is read as an attribute of `layer`, and the part after the dot is read
  as an attribute of that object (e.g. '_query_dense.kernel').
  """
  resolved_weights = []
  for dotted_name in cls._weight_names(layer):
    owner_attr, _, weight_attr = dotted_name.rpartition('.')
    owner = getattr(layer, owner_attr)
    resolved_weights.append(getattr(owner, weight_attr))
  return resolved_weights
223+
205224
if cls._is_rnn_layer(layer):
206225
layer.get_prunable_weights = get_prunable_weights_rnn
226+
elif cls._is_mha_layer(layer):
227+
layer.get_prunable_weights = get_prunable_weights_mha
207228
else:
208229
layer.get_prunable_weights = get_prunable_weights
209230

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
# pylint: disable=missing-docstring
16+
"""Train a simple model with MultiHeadAttention layer on MNIST dataset
17+
and prune it.
18+
"""
19+
import tensorflow as tf
20+
21+
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
22+
from tensorflow_model_optimization.python.core.sparsity.keras import prune
23+
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
24+
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
25+
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
26+
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
27+
28+
# Fix the TensorFlow random seed used by this example.
tf.random.set_seed(42)

ConstantSparsity = pruning_schedule.ConstantSparsity

# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model: the input is used as both query and value of the
# MultiHeadAttention layer, followed by Flatten and a 10-unit Dense layer.
# The tensor is named `inputs` so the builtin `input` is not shadowed.
inputs = tf.keras.layers.Input(shape=(28, 28))
x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16, name='mha')(
    query=inputs, value=inputs
)
x = tf.keras.layers.Flatten()(x)
out = tf.keras.layers.Dense(10)(x)
model = tf.keras.Model(inputs=inputs, outputs=out)

# Train the digit classification model
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

model.fit(
    train_images, train_labels, epochs=10, validation_split=0.1,
)

score = model.evaluate(test_images, test_labels, verbose=0)
print('Model test loss:', score[0])
print('Model test accuracy:', score[1])

# Define parameters for pruning

batch_size = 128
epochs = 3
validation_split = 0.1  # 10% of training set will be used for validation set.

callbacks = [
    pruning_callbacks.UpdatePruningStep(),
    pruning_callbacks.PruningSummaries(log_dir='/tmp/logs')
]

# Derive the pruning start from the length of the pruning run. A fixed
# begin_step of 2000 is never reached here: with the standard 60,000-image
# MNIST training set, 3 epochs of about 422 batches give roughly 1,266 steps,
# so no weights would ever be pruned. Pruning now starts after one epoch.
num_train_samples = int(len(train_images) * (1 - validation_split))
steps_per_epoch = -(-num_train_samples // batch_size)  # Ceiling division.

pruning_params = {
    'pruning_schedule': ConstantSparsity(
        0.75, begin_step=steps_per_epoch, frequency=100)
}

model_for_pruning = prune.prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

model_for_pruning.fit(
    train_images,
    train_labels,
    batch_size=batch_size,
    epochs=epochs,
    callbacks=callbacks,
    validation_split=validation_split,
)

score = model_for_pruning.evaluate(test_images, test_labels, verbose=0)
print('Pruned model test loss:', score[0])
print('Pruned model test accuracy:', score[1])

0 commit comments

Comments
 (0)