Skip to content

Commit 8211578

Browse files
Merge pull request #686 from johan-gras:feature/clustering-rnn
PiperOrigin-RevId: 375626697
2 parents: 6be78ae + 22f13b3 | commit 8211578

File tree

9 files changed

+492
-86
lines changed

9 files changed

+492
-86
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -236,6 +236,14 @@ def _add_clustering_wrapper(layer):
236236
return layer
237237
if isinstance(layer, InputLayer):
238238
return layer.__class__.from_config(layer.get_config())
239+
if isinstance(layer, tf.keras.layers.RNN):
240+
return cluster_wrapper.ClusterWeightsRNN(
241+
layer,
242+
number_of_clusters,
243+
cluster_centroids_init,
244+
preserve_sparsity,
245+
**kwargs,
246+
)
239247

240248
return cluster_wrapper.ClusterWeights(layer, number_of_clusters,
241249
cluster_centroids_init,
@@ -303,7 +311,7 @@ def _strip_clustering_wrapper(layer):
303311
for (position_variable,
304312
weight_name) in layer.position_original_weights.items():
305313
# Add the clustered weights at the correct position
306-
clustered_weight = getattr(layer.layer, weight_name)
314+
clustered_weight = layer.get_weight_from_layer(weight_name)
307315
updated_weights.insert(position_variable, clustered_weight)
308316

309317
# Construct a clean layer with the updated weights

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 155 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -344,5 +344,160 @@ def on_train_batch_end(self, batch, logs=None):
344344
callbacks=[CheckWeightsCallback()])
345345

346346

347+
class ClusterRNNIntegrationTest(tf.test.TestCase, parameterized.TestCase):
348+
"""Integration tests for clustering RNN layers."""
349+
350+
def setUp(self):
351+
super(ClusterRNNIntegrationTest, self).setUp()
352+
self.max_features = 10
353+
self.maxlen = 2
354+
self.batch_size = 32
355+
self.x_train = np.random.random((64, self.maxlen))
356+
self.y_train = np.random.randint(0, 2, (64,))
357+
358+
self.params_clustering = {
359+
"number_of_clusters": 16,
360+
"cluster_centroids_init": CentroidInitialization.KMEANS_PLUS_PLUS,
361+
}
362+
363+
def _train(self, model):
364+
model.compile(loss="binary_crossentropy", optimizer="adam",
365+
metrics=["accuracy"])
366+
model.fit(
367+
self.x_train,
368+
self.y_train,
369+
batch_size=self.batch_size,
370+
epochs=1,
371+
)
372+
373+
def _clusterTrainStrip(self, model):
374+
clustered_model = cluster.cluster_weights(
375+
model,
376+
**self.params_clustering,
377+
)
378+
self._train(clustered_model)
379+
stripped_model = cluster.strip_clustering(clustered_model)
380+
381+
return stripped_model
382+
383+
def __assertNbUniqueWeights(self, weight, expected_unique_weights):
384+
nr_unique_weights = len(np.unique(weight.numpy().flatten()))
385+
assert nr_unique_weights == expected_unique_weights
386+
387+
@keras_parameterized.run_all_keras_modes
388+
def testClusterSimpleRNN(self):
389+
model = keras.models.Sequential()
390+
model.add(keras.layers.Embedding(self.max_features, 16,
391+
input_length=self.maxlen))
392+
model.add(keras.layers.SimpleRNN(16, return_sequences=True))
393+
model.add(keras.layers.SimpleRNN(16))
394+
model.add(keras.layers.Dense(1))
395+
model.add(keras.layers.Activation("sigmoid"))
396+
397+
stripped_model = self._clusterTrainStrip(model)
398+
399+
self.__assertNbUniqueWeights(
400+
weight=stripped_model.layers[1].cell.kernel,
401+
expected_unique_weights=self.params_clustering["number_of_clusters"],
402+
)
403+
self.__assertNbUniqueWeights(
404+
weight=stripped_model.layers[1].cell.recurrent_kernel,
405+
expected_unique_weights=self.params_clustering["number_of_clusters"],
406+
)
407+
408+
self._train(stripped_model)
409+
410+
@keras_parameterized.run_all_keras_modes
411+
def testClusterLSTM(self):
412+
model = keras.models.Sequential()
413+
model.add(keras.layers.Embedding(self.max_features, 16,
414+
input_length=self.maxlen))
415+
model.add(keras.layers.LSTM(16, return_sequences=True))
416+
model.add(keras.layers.LSTM(16))
417+
model.add(keras.layers.Dense(1))
418+
model.add(keras.layers.Activation("sigmoid"))
419+
420+
stripped_model = self._clusterTrainStrip(model)
421+
422+
self.__assertNbUniqueWeights(
423+
weight=stripped_model.layers[1].cell.kernel,
424+
expected_unique_weights=self.params_clustering["number_of_clusters"],
425+
)
426+
self.__assertNbUniqueWeights(
427+
weight=stripped_model.layers[1].cell.recurrent_kernel,
428+
expected_unique_weights=self.params_clustering["number_of_clusters"],
429+
)
430+
431+
self._train(stripped_model)
432+
433+
@keras_parameterized.run_all_keras_modes
434+
def testClusterGRU(self):
435+
model = keras.models.Sequential()
436+
model.add(keras.layers.Embedding(self.max_features, 16,
437+
input_length=self.maxlen))
438+
model.add(keras.layers.GRU(16, return_sequences=True))
439+
model.add(keras.layers.GRU(16))
440+
model.add(keras.layers.Dense(1))
441+
model.add(keras.layers.Activation("sigmoid"))
442+
443+
stripped_model = self._clusterTrainStrip(model)
444+
445+
self.__assertNbUniqueWeights(
446+
weight=stripped_model.layers[1].cell.kernel,
447+
expected_unique_weights=self.params_clustering["number_of_clusters"],
448+
)
449+
self.__assertNbUniqueWeights(
450+
weight=stripped_model.layers[1].cell.recurrent_kernel,
451+
expected_unique_weights=self.params_clustering["number_of_clusters"],
452+
)
453+
454+
self._train(stripped_model)
455+
456+
@keras_parameterized.run_all_keras_modes
457+
def testClusterPeepholeLSTM(self):
458+
model = keras.models.Sequential()
459+
model.add(keras.layers.Embedding(self.max_features, 16,
460+
input_length=self.maxlen))
461+
model.add(keras.layers.RNN(tf.keras.experimental.PeepholeLSTMCell(16)))
462+
model.add(keras.layers.Dense(1))
463+
model.add(keras.layers.Activation("sigmoid"))
464+
465+
# PeepholeLSTM not supported yet.
466+
with self.assertRaises(ValueError):
467+
self._clusterTrainStrip(model)
468+
469+
@keras_parameterized.run_all_keras_modes
470+
def testClusterBidirectional(self):
471+
model = keras.models.Sequential()
472+
model.add(keras.layers.Embedding(self.max_features, 16,
473+
input_length=self.maxlen))
474+
model.add(keras.layers.Bidirectional(keras.layers.SimpleRNN(16)))
475+
model.add(keras.layers.Dense(1))
476+
model.add(keras.layers.Activation("sigmoid"))
477+
478+
# Bidirectional not supported yet.
479+
with self.assertRaises(ValueError):
480+
self._clusterTrainStrip(model)
481+
482+
@keras_parameterized.run_all_keras_modes
483+
def testClusterStackedRNNCells(self):
484+
model = keras.models.Sequential()
485+
model.add(keras.layers.Embedding(self.max_features, 16,
486+
input_length=self.maxlen))
487+
model.add(
488+
tf.keras.layers.RNN(
489+
tf.keras.layers.StackedRNNCells(
490+
[keras.layers.SimpleRNNCell(16) for _ in range(2)]
491+
)
492+
)
493+
)
494+
model.add(keras.layers.Dense(1))
495+
model.add(keras.layers.Activation("sigmoid"))
496+
497+
# StackedRNNCells not supported yet.
498+
with self.assertRaises(ValueError):
499+
self._clusterTrainStrip(model)
500+
501+
347502
if __name__ == "__main__":
348503
test.main()

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -139,6 +139,12 @@ def __init__(self,
139139
def _make_layer_name(layer):
140140
return '{}_{}'.format('cluster', layer.name)
141141

142+
def get_weight_from_layer(self, weight_name):
143+
return getattr(self.layer, weight_name)
144+
145+
def set_weight_to_layer(self, weight_name, new_weight):
146+
setattr(self.layer, weight_name, new_weight)
147+
142148
def build(self, input_shape):
143149
super(ClusterWeights, self).build(input_shape)
144150
self.build_input_shape = input_shape
@@ -148,7 +154,7 @@ def build(self, input_shape):
148154
# Store the original weight in this wrapper
149155
# The child reference will be overridden in
150156
# update_clustered_weights_associations
151-
original_weight = getattr(self.layer, weight_name)
157+
original_weight = self.get_weight_from_layer(weight_name)
152158
self.original_clusterable_weights[weight_name] = original_weight
153159
setattr(self, 'original_weight_' + weight_name,
154160
original_weight) # Track the variable
@@ -220,7 +226,7 @@ def update_clustered_weights_associations(self):
220226
self.sparsity_masks[weight_name])
221227

222228
# Replace the weights with their clustered counterparts
223-
setattr(self.layer, weight_name, clustered_weights)
229+
self.set_weight_to_layer(weight_name, clustered_weights)
224230

225231
def call(self, inputs, training=None, **kwargs):
226232
# Update cluster associations in order to set the latest weights
@@ -292,3 +298,13 @@ def get_weights(self):
292298

293299
def set_weights(self, weights):
294300
self.layer.set_weights(weights)
301+
302+
303+
class ClusterWeightsRNN(ClusterWeights):
304+
"""This wrapper augments a keras RNN layer so that the weights can be clustered."""
305+
306+
def get_weight_from_layer(self, weight_name):
307+
return getattr(self.layer.cell, weight_name)
308+
309+
def set_weight_to_layer(self, weight_name, new_weight):
310+
setattr(self.layer.cell, weight_name, new_weight)

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py

Lines changed: 25 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -133,6 +133,18 @@ class ClusteringLookupRegistry(object):
133133
'kernel': ConvolutionalWeightsCA,
134134
'bias': BiasWeightsCA
135135
},
136+
layers.LSTM: {
137+
'kernel': DenseWeightsCA,
138+
'recurrent_kernel': DenseWeightsCA
139+
},
140+
layers.GRU: {
141+
'kernel': DenseWeightsCA,
142+
'recurrent_kernel': DenseWeightsCA
143+
},
144+
layers.SimpleRNN: {
145+
'kernel': DenseWeightsCA,
146+
'recurrent_kernel': DenseWeightsCA
147+
},
136148
}
137149

138150
@classmethod
@@ -235,38 +247,11 @@ class ClusteringRegistry(object):
235247
layers.LayerNormalization: [],
236248
}
237249

238-
_RNN_CELLS_WEIGHTS_MAP = {
239-
# NOTE: RNN cells are added via compat.v1 and compat.v2 to support legacy
240-
# TensorFlow 2.X behavior where the v2 RNN uses the v1 RNNCell instead of
241-
# the v2 RNNCell.
242-
tf.compat.v1.keras.layers.GRUCell: ['kernel', 'recurrent_kernel'],
243-
tf.compat.v2.keras.layers.GRUCell: ['kernel', 'recurrent_kernel'],
244-
tf.compat.v1.keras.layers.LSTMCell: ['kernel', 'recurrent_kernel'],
245-
tf.compat.v2.keras.layers.LSTMCell: ['kernel', 'recurrent_kernel'],
246-
tf.compat.v1.keras.experimental.PeepholeLSTMCell: [
247-
'kernel', 'recurrent_kernel'
248-
],
249-
tf.compat.v2.keras.experimental.PeepholeLSTMCell: [
250-
'kernel', 'recurrent_kernel'
251-
],
252-
tf.compat.v1.keras.layers.SimpleRNNCell: ['kernel', 'recurrent_kernel'],
253-
tf.compat.v2.keras.layers.SimpleRNNCell: ['kernel', 'recurrent_kernel'],
254-
}
255-
256-
_RNN_LAYERS = {
250+
_SUPPORTED_RNN_LAYERS = frozenset([
257251
layers.GRU,
258252
layers.LSTM,
259-
layers.RNN,
260253
layers.SimpleRNN,
261-
}
262-
263-
_RNN_CELLS_STR = ', '.join(str(_RNN_CELLS_WEIGHTS_MAP.keys()))
264-
265-
_RNN_CELL_ERROR_MSG = (
266-
'RNN Layer {} contains cell type {} which is either not supported or does'
267-
'not inherit ClusterableLayer. The cell must be one of {}, or implement '
268-
'ClusterableLayer.'
269-
)
254+
])
270255

271256
@classmethod
272257
def supports(cls, layer):
@@ -281,31 +266,17 @@ def supports(cls, layer):
281266
"""
282267
# Automatically enable layers with zero trainable weights.
283268
# Example: Reshape, AveragePooling2D, Maximum/Minimum, etc.
284-
if not layer.trainable_weights:
269+
if not layer.trainable_weights and not isinstance(layer, layers.RNN):
285270
return True
286271

287272
if layer.__class__ in cls._LAYERS_WEIGHTS_MAP:
288273
return True
289274

290-
if layer.__class__ in cls._RNN_LAYERS:
291-
for cell in cls._get_rnn_cells(layer):
292-
if (cell.__class__ not in cls._RNN_CELLS_WEIGHTS_MAP
293-
and not isinstance(cell, clusterable_layer.ClusterableLayer)):
294-
return False
275+
if layer.__class__ in cls._SUPPORTED_RNN_LAYERS:
295276
return True
296277

297278
return False
298279

299-
@staticmethod
300-
def _get_rnn_cells(rnn_layer):
301-
if isinstance(rnn_layer.cell, layers.StackedRNNCells):
302-
return rnn_layer.cell.cells
303-
return [rnn_layer.cell]
304-
305-
@classmethod
306-
def _is_rnn_layer(cls, layer):
307-
return layer.__class__ in cls._RNN_LAYERS
308-
309280
@classmethod
310281
def _weight_names(cls, layer):
311282
# For layers with zero trainable weights, like Reshape, Pooling.
@@ -323,7 +294,6 @@ def make_clusterable(cls, layer):
323294
324295
Returns:
325296
The modified layer object.
326-
327297
"""
328298

329299
if not cls.supports(layer):
@@ -334,23 +304,17 @@ def get_clusterable_weights():
334304
for weight_name in cls._weight_names(layer)]
335305

336306
def get_clusterable_weights_rnn(): # pylint: disable=missing-docstring
337-
def get_clusterable_weights_rnn_cell(cell):
338-
if cell.__class__ in cls._RNN_CELLS_WEIGHTS_MAP:
339-
return [(weight, getattr(cell, weight))
340-
for weight in cls._RNN_CELLS_WEIGHTS_MAP[cell.__class__]]
341-
342-
if isinstance(cell, clusterable_layer.ClusterableLayer):
343-
return cell.get_clusterable_weights()
344-
345-
raise ValueError(cls._RNN_CELL_ERROR_MSG.format(
346-
layer.__class__, cell.__class__, cls._RNN_CELLS_WEIGHTS_MAP.keys()))
347-
348-
clusterable_weights = []
349-
for rnn_cell in cls._get_rnn_cells(layer):
350-
clusterable_weights.extend(get_clusterable_weights_rnn_cell(rnn_cell))
307+
if isinstance(layer.cell, clusterable_layer.ClusterableLayer):
308+
raise ValueError(
309+
'ClusterableLayer is not yet supported for RNNs based layer.')
310+
311+
clusterable_weights = [
312+
('kernel', layer.cell.kernel),
313+
('recurrent_kernel', layer.cell.recurrent_kernel),
314+
]
351315
return clusterable_weights
352316

353-
if cls._is_rnn_layer(layer):
317+
if layer.__class__ in cls._SUPPORTED_RNN_LAYERS:
354318
layer.get_clusterable_weights = get_clusterable_weights_rnn
355319
else:
356320
layer.get_clusterable_weights = get_clusterable_weights

0 commit comments

Comments
 (0)