Skip to content

Commit 22f13b3

Browse files
committed
[Clustering] Enable RNN layers
[Clustering] Enable RNN layers Change-Id: Ic44fe625f8e927bceb06a8eb19f3492b486655ef
1 parent 465281b commit 22f13b3

File tree

9 files changed

+475
-84
lines changed

9 files changed

+475
-84
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
2020
from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids
21+
from tensorflow_model_optimization.python.core.clustering.keras.clustering_registry import ClusteringRegistry
2122

2223
k = tf.keras.backend
2324
CustomObjectScope = tf.keras.utils.CustomObjectScope
@@ -236,6 +237,14 @@ def _add_clustering_wrapper(layer):
236237
return layer
237238
if isinstance(layer, InputLayer):
238239
return layer.__class__.from_config(layer.get_config())
240+
if isinstance(layer, tf.keras.layers.RNN):
241+
return cluster_wrapper.ClusterWeightsRNN(
242+
layer,
243+
number_of_clusters,
244+
cluster_centroids_init,
245+
preserve_sparsity,
246+
**kwargs,
247+
)
239248

240249
return cluster_wrapper.ClusterWeights(layer, number_of_clusters,
241250
cluster_centroids_init,
@@ -303,7 +312,7 @@ def _strip_clustering_wrapper(layer):
303312
for (position_variable,
304313
weight_name) in layer.position_original_weights.items():
305314
# Add the clustered weights at the correct position
306-
clustered_weight = getattr(layer.layer, weight_name)
315+
clustered_weight = layer.get_weight_from_layer(weight_name)
307316
updated_weights.insert(position_variable, clustered_weight)
308317

309318
# Construct a clean layer with the updated weights

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,5 +344,152 @@ def on_train_batch_end(self, batch, logs=None):
344344
callbacks=[CheckWeightsCallback()])
345345

346346

347+
class ClusterRNNIntegrationTest(tf.test.TestCase, parameterized.TestCase):
348+
"""Integration tests for clustering RNN layers."""
349+
350+
def setUp(self):
351+
self.max_features = 10
352+
self.maxlen = 2
353+
self.batch_size = 32
354+
self.x_train = np.random.random((64, self.maxlen))
355+
self.y_train = np.random.randint(0, 2, (64,))
356+
357+
self.params_clustering = {
358+
"number_of_clusters": 16,
359+
"cluster_centroids_init": CentroidInitialization.KMEANS_PLUS_PLUS,
360+
}
361+
362+
def _train(self, model):
363+
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
364+
model.fit(
365+
self.x_train,
366+
self.y_train,
367+
batch_size=self.batch_size,
368+
epochs=1,
369+
)
370+
371+
def _clusterTrainStrip(self, model):
372+
clustered_model = cluster.cluster_weights(
373+
model,
374+
**self.params_clustering,
375+
)
376+
self._train(clustered_model)
377+
stripped_model = cluster.strip_clustering(clustered_model)
378+
379+
return stripped_model
380+
381+
def __assertNbUniqueWeights(self, weight, expected_unique_weights):
382+
nr_unique_weights = len(np.unique(weight.numpy().flatten()))
383+
assert nr_unique_weights == expected_unique_weights
384+
385+
@keras_parameterized.run_all_keras_modes
386+
def testClusterSimpleRNN(self):
387+
model = keras.models.Sequential()
388+
model.add(keras.layers.Embedding(self.max_features, 16, input_length=self.maxlen))
389+
model.add(keras.layers.SimpleRNN(16, return_sequences=True))
390+
model.add(keras.layers.SimpleRNN(16))
391+
model.add(keras.layers.Dense(1))
392+
model.add(keras.layers.Activation("sigmoid"))
393+
394+
stripped_model = self._clusterTrainStrip(model)
395+
396+
self.__assertNbUniqueWeights(
397+
weight=stripped_model.layers[1].cell.kernel,
398+
expected_unique_weights=self.params_clustering["number_of_clusters"],
399+
)
400+
self.__assertNbUniqueWeights(
401+
weight=stripped_model.layers[1].cell.recurrent_kernel,
402+
expected_unique_weights=self.params_clustering["number_of_clusters"],
403+
)
404+
405+
self._train(stripped_model)
406+
407+
@keras_parameterized.run_all_keras_modes
408+
def testClusterLSTM(self):
409+
model = keras.models.Sequential()
410+
model.add(keras.layers.Embedding(self.max_features, 16, input_length=self.maxlen))
411+
model.add(keras.layers.LSTM(16, return_sequences=True))
412+
model.add(keras.layers.LSTM(16))
413+
model.add(keras.layers.Dense(1))
414+
model.add(keras.layers.Activation("sigmoid"))
415+
416+
stripped_model = self._clusterTrainStrip(model)
417+
418+
self.__assertNbUniqueWeights(
419+
weight=stripped_model.layers[1].cell.kernel,
420+
expected_unique_weights=self.params_clustering["number_of_clusters"],
421+
)
422+
self.__assertNbUniqueWeights(
423+
weight=stripped_model.layers[1].cell.recurrent_kernel,
424+
expected_unique_weights=self.params_clustering["number_of_clusters"],
425+
)
426+
427+
self._train(stripped_model)
428+
429+
@keras_parameterized.run_all_keras_modes
430+
def testClusterGRU(self):
431+
model = keras.models.Sequential()
432+
model.add(keras.layers.Embedding(self.max_features, 16, input_length=self.maxlen))
433+
model.add(keras.layers.GRU(16, return_sequences=True))
434+
model.add(keras.layers.GRU(16))
435+
model.add(keras.layers.Dense(1))
436+
model.add(keras.layers.Activation("sigmoid"))
437+
438+
stripped_model = self._clusterTrainStrip(model)
439+
440+
self.__assertNbUniqueWeights(
441+
weight=stripped_model.layers[1].cell.kernel,
442+
expected_unique_weights=self.params_clustering["number_of_clusters"],
443+
)
444+
self.__assertNbUniqueWeights(
445+
weight=stripped_model.layers[1].cell.recurrent_kernel,
446+
expected_unique_weights=self.params_clustering["number_of_clusters"],
447+
)
448+
449+
self._train(stripped_model)
450+
451+
@keras_parameterized.run_all_keras_modes
452+
def testClusterPeepholeLSTM(self):
453+
model = keras.models.Sequential()
454+
model.add(keras.layers.Embedding(self.max_features, 16, input_length=self.maxlen))
455+
model.add(keras.layers.RNN(tf.keras.experimental.PeepholeLSTMCell(16)))
456+
model.add(keras.layers.Dense(1))
457+
model.add(keras.layers.Activation("sigmoid"))
458+
459+
# PeepholeLSTM not supported yet.
460+
with self.assertRaises(ValueError):
461+
self._clusterTrainStrip(model)
462+
463+
@keras_parameterized.run_all_keras_modes
464+
def testClusterBidirectional(self):
465+
model = keras.models.Sequential()
466+
model.add(keras.layers.Embedding(self.max_features, 16, input_length=self.maxlen))
467+
model.add(keras.layers.Bidirectional(keras.layers.SimpleRNN(16)))
468+
model.add(keras.layers.Dense(1))
469+
model.add(keras.layers.Activation("sigmoid"))
470+
471+
# Bidirectional not supported yet.
472+
with self.assertRaises(ValueError):
473+
self._clusterTrainStrip(model)
474+
475+
@keras_parameterized.run_all_keras_modes
476+
def testClusterStackedRNNCells(self):
477+
model = keras.models.Sequential()
478+
model.add(keras.layers.Embedding(self.max_features, 16, input_length=self.maxlen))
479+
model.add(
480+
tf.keras.layers.RNN(
481+
tf.keras.layers.StackedRNNCells(
482+
[keras.layers.SimpleRNNCell(16) for _ in range(2)]
483+
)
484+
)
485+
)
486+
model.add(keras.layers.Dense(1))
487+
model.add(keras.layers.Activation("sigmoid"))
488+
489+
# StackedRNNCells not supported yet.
490+
with self.assertRaises(ValueError):
491+
self._clusterTrainStrip(model)
492+
493+
347494
if __name__ == "__main__":
348495
test.main()

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,12 @@ def __init__(self,
139139
def _make_layer_name(layer):
140140
return '{}_{}'.format('cluster', layer.name)
141141

142+
def get_weight_from_layer(self, weight_name):
143+
return getattr(self.layer, weight_name)
144+
145+
def set_weight_to_layer(self, weight_name, new_weight):
146+
setattr(self.layer, weight_name, new_weight)
147+
142148
def build(self, input_shape):
143149
super(ClusterWeights, self).build(input_shape)
144150
self.build_input_shape = input_shape
@@ -148,7 +154,7 @@ def build(self, input_shape):
148154
# Store the original weight in this wrapper
149155
# The child reference will be overridden in
150156
# update_clustered_weights_associations
151-
original_weight = getattr(self.layer, weight_name)
157+
original_weight = self.get_weight_from_layer(weight_name)
152158
self.original_clusterable_weights[weight_name] = original_weight
153159
setattr(self, 'original_weight_' + weight_name,
154160
original_weight) # Track the variable
@@ -220,7 +226,7 @@ def update_clustered_weights_associations(self):
220226
self.sparsity_masks[weight_name])
221227

222228
# Replace the weights with their clustered counterparts
223-
setattr(self.layer, weight_name, clustered_weights)
229+
self.set_weight_to_layer(weight_name, clustered_weights)
224230

225231
def call(self, inputs, training=None, **kwargs):
226232
# Update cluster associations in order to set the latest weights
@@ -292,3 +298,13 @@ def get_weights(self):
292298

293299
def set_weights(self, weights):
294300
self.layer.set_weights(weights)
301+
302+
303+
class ClusterWeightsRNN(ClusterWeights):
304+
"""This wrapper augments a keras RNN layer so that the weights can be clustered."""
305+
306+
def get_weight_from_layer(self, weight_name):
307+
return getattr(self.layer.cell, weight_name)
308+
309+
def set_weight_to_layer(self, weight_name, new_weight):
310+
setattr(self.layer.cell, weight_name, new_weight)

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py

Lines changed: 22 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,18 @@ class ClusteringLookupRegistry(object):
133133
'kernel': ConvolutionalWeightsCA,
134134
'bias': BiasWeightsCA
135135
},
136+
layers.LSTM: {
137+
'kernel': DenseWeightsCA,
138+
'recurrent_kernel': DenseWeightsCA
139+
},
140+
layers.GRU: {
141+
'kernel': DenseWeightsCA,
142+
'recurrent_kernel': DenseWeightsCA
143+
},
144+
layers.SimpleRNN: {
145+
'kernel': DenseWeightsCA,
146+
'recurrent_kernel': DenseWeightsCA
147+
},
136148
}
137149

138150
@classmethod
@@ -235,39 +247,12 @@ class ClusteringRegistry(object):
235247
layers.LayerNormalization: [],
236248
}
237249

238-
_RNN_CELLS_WEIGHTS_MAP = {
239-
# NOTE: RNN cells are added via compat.v1 and compat.v2 to support legacy
240-
# TensorFlow 2.X behavior where the v2 RNN uses the v1 RNNCell instead of
241-
# the v2 RNNCell.
242-
tf.compat.v1.keras.layers.GRUCell: ['kernel', 'recurrent_kernel'],
243-
tf.compat.v2.keras.layers.GRUCell: ['kernel', 'recurrent_kernel'],
244-
tf.compat.v1.keras.layers.LSTMCell: ['kernel', 'recurrent_kernel'],
245-
tf.compat.v2.keras.layers.LSTMCell: ['kernel', 'recurrent_kernel'],
246-
tf.compat.v1.keras.experimental.PeepholeLSTMCell: [
247-
'kernel', 'recurrent_kernel'
248-
],
249-
tf.compat.v2.keras.experimental.PeepholeLSTMCell: [
250-
'kernel', 'recurrent_kernel'
251-
],
252-
tf.compat.v1.keras.layers.SimpleRNNCell: ['kernel', 'recurrent_kernel'],
253-
tf.compat.v2.keras.layers.SimpleRNNCell: ['kernel', 'recurrent_kernel'],
254-
}
255-
256-
_RNN_LAYERS = {
250+
_SUPPORTED_RNN_LAYERS = {
257251
layers.GRU,
258252
layers.LSTM,
259-
layers.RNN,
260253
layers.SimpleRNN,
261254
}
262255

263-
_RNN_CELLS_STR = ', '.join(str(_RNN_CELLS_WEIGHTS_MAP.keys()))
264-
265-
_RNN_CELL_ERROR_MSG = (
266-
'RNN Layer {} contains cell type {} which is either not supported or does'
267-
'not inherit ClusterableLayer. The cell must be one of {}, or implement '
268-
'ClusterableLayer.'
269-
)
270-
271256
@classmethod
272257
def supports(cls, layer):
273258
"""Returns whether the registry supports this layer type.
@@ -281,31 +266,17 @@ def supports(cls, layer):
281266
"""
282267
# Automatically enable layers with zero trainable weights.
283268
# Example: Reshape, AveragePooling2D, Maximum/Minimum, etc.
284-
if not layer.trainable_weights:
269+
if not layer.trainable_weights and not isinstance(layer, layers.RNN):
285270
return True
286271

287272
if layer.__class__ in cls._LAYERS_WEIGHTS_MAP:
288273
return True
289274

290-
if layer.__class__ in cls._RNN_LAYERS:
291-
for cell in cls._get_rnn_cells(layer):
292-
if (cell.__class__ not in cls._RNN_CELLS_WEIGHTS_MAP
293-
and not isinstance(cell, clusterable_layer.ClusterableLayer)):
294-
return False
275+
if layer.__class__ in cls._SUPPORTED_RNN_LAYERS:
295276
return True
296277

297278
return False
298279

299-
@staticmethod
300-
def _get_rnn_cells(rnn_layer):
301-
if isinstance(rnn_layer.cell, layers.StackedRNNCells):
302-
return rnn_layer.cell.cells
303-
return [rnn_layer.cell]
304-
305-
@classmethod
306-
def _is_rnn_layer(cls, layer):
307-
return layer.__class__ in cls._RNN_LAYERS
308-
309280
@classmethod
310281
def _weight_names(cls, layer):
311282
# For layers with zero trainable weights, like Reshape, Pooling.
@@ -323,7 +294,6 @@ def make_clusterable(cls, layer):
323294
324295
Returns:
325296
The modified layer object.
326-
327297
"""
328298

329299
if not cls.supports(layer):
@@ -334,23 +304,16 @@ def get_clusterable_weights():
334304
for weight_name in cls._weight_names(layer)]
335305

336306
def get_clusterable_weights_rnn(): # pylint: disable=missing-docstring
337-
def get_clusterable_weights_rnn_cell(cell):
338-
if cell.__class__ in cls._RNN_CELLS_WEIGHTS_MAP:
339-
return [(weight, getattr(cell, weight))
340-
for weight in cls._RNN_CELLS_WEIGHTS_MAP[cell.__class__]]
341-
342-
if isinstance(cell, clusterable_layer.ClusterableLayer):
343-
return cell.get_clusterable_weights()
344-
345-
raise ValueError(cls._RNN_CELL_ERROR_MSG.format(
346-
layer.__class__, cell.__class__, cls._RNN_CELLS_WEIGHTS_MAP.keys()))
307+
if isinstance(layer.cell, clusterable_layer.ClusterableLayer):
308+
raise ValueError("ClusterableLayer is not yet supported for RNNs based layer.")
347309

348-
clusterable_weights = []
349-
for rnn_cell in cls._get_rnn_cells(layer):
350-
clusterable_weights.extend(get_clusterable_weights_rnn_cell(rnn_cell))
310+
clusterable_weights = [
311+
('kernel', layer.cell.kernel),
312+
('recurrent_kernel', layer.cell.recurrent_kernel),
313+
]
351314
return clusterable_weights
352315

353-
if cls._is_rnn_layer(layer):
316+
if layer.__class__ in cls._SUPPORTED_RNN_LAYERS:
354317
layer.get_clusterable_weights = get_clusterable_weights_rnn
355318
else:
356319
layer.get_clusterable_weights = get_clusterable_weights

0 commit comments

Comments
 (0)