Skip to content

Commit 316e941

Browse files
Merge pull request #952 from wwwind:clustering_conv1d_transpose
PiperOrigin-RevId: 448273933
2 parents da6897c + 16cb032 commit 316e941

File tree

3 files changed

+110
-0
lines changed

3 files changed

+110
-0
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ def dataset_generator2(self):
8585
for x, y in zip(self.x_train2, self.y_train2):
8686
yield np.array([x]), np.array([y])
8787

88+
def _batch(self, dims, batch_size):
89+
if dims[0] is None:
90+
dims[0] = batch_size
91+
return dims
92+
8893
def end_to_end_testing(self, original_model, clusters_check=None):
8994
"""Test End to End clustering."""
9095

@@ -225,6 +230,80 @@ def testSparsityIsPreservedDuringTraining(self):
225230
nr_of_unique_weights_after,
226231
clustering_params["number_of_clusters"])
227232

233+
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
234+
def testEndToEndSequential(self):
235+
"""Test End to End clustering - sequential model."""
236+
original_model = keras.Sequential([
237+
layers.Dense(5, input_shape=(5,)),
238+
layers.Dense(5),
239+
])
240+
241+
def clusters_check(stripped_model):
242+
# dense layer
243+
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
244+
unique_weights = set(weights_as_list)
245+
self.assertLessEqual(
246+
len(unique_weights), self.params["number_of_clusters"])
247+
248+
self.end_to_end_testing(original_model, clusters_check)
249+
250+
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
251+
def testEndToEndConv1DAndConv1DTranspose(self):
252+
"""Test End to End clustering - model with Conv1D and Conv1DTranspose."""
253+
inp = layers.Input(batch_shape=(1, 16))
254+
x = layers.Conv1D(
255+
10, 16, 4, padding="valid", use_bias=False)(
256+
tf.expand_dims(inp, axis=-1))
257+
y = layers.Conv1DTranspose(1, 16, 4, padding="valid", use_bias=False)(x)
258+
model = keras.models.Model(inputs=inp, outputs=[y])
259+
260+
def apply_clustering(layer):
261+
if isinstance(layer, keras.layers.Conv1D) or isinstance(
262+
layer, keras.layers.Conv1DTranspose):
263+
return cluster.cluster_weights(layer, **self.params)
264+
return layer
265+
266+
model_to_cluster = keras.models.clone_model(
267+
model,
268+
clone_function=apply_clustering,
269+
)
270+
271+
model_to_cluster.compile(
272+
loss=keras.losses.categorical_crossentropy,
273+
optimizer="adam",
274+
metrics=["accuracy"]
275+
)
276+
model_to_cluster.fit(
277+
np.random.randn(*self._batch(model.input.get_shape().as_list(), 16)),
278+
np.random.randn(*self._batch(model.output.get_shape().as_list(), 16)),
279+
steps_per_epoch=1)
280+
clustered_model = cluster.strip_clustering(model_to_cluster)
281+
282+
def do_checks(layer, layer_name):
283+
self.assertEqual(layer.name, layer_name)
284+
unique_weights = np.unique(layer.weights[0].numpy().flatten())
285+
self.assertLessEqual(
286+
len(unique_weights), self.params["number_of_clusters"])
287+
288+
do_checks(clustered_model.layers[2], "conv1d")
289+
do_checks(clustered_model.layers[3], "conv1d_transpose")
290+
291+
def testStripClusteringSequentialModelWithRegulariser(self):
292+
"""Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
293+
original_model = keras.Sequential([
294+
layers.Dense(5, input_shape=(5,)),
295+
layers.Dense(5, kernel_regularizer=tf.keras.regularizers.L1(0.01)),
296+
])
297+
298+
def clusters_check(stripped_model):
299+
# dense layer
300+
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
301+
unique_weights = set(weights_as_list)
302+
self.assertLessEqual(
303+
len(unique_weights), self.params["number_of_clusters"])
304+
305+
self.end_to_end_testing(original_model, clusters_check)
306+
228307
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
229308
def testEndToEndFunctional(self):
230309
"""Test End to End clustering - functional model."""

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,11 @@ def setUp(self):
127127
self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1))
128128
self.keras_dense_layer = layers.Dense(10)
129129
self.keras_conv1d_layer = layers.Conv1D(filters=3, kernel_size=(5))
130+
self.keras_conv1d_tr_layer = layers.Conv1DTranspose(
131+
filters=3, kernel_size=(5))
130132
self.keras_conv2d_layer = layers.Conv2D(filters=3, kernel_size=(4, 5))
133+
self.keras_conv2d_tr_layer = layers.Conv2DTranspose(
134+
filters=3, kernel_size=(4, 5))
131135
self.keras_conv3d_layer = layers.Conv3D(filters=2, kernel_size=(3, 4, 5))
132136
self.keras_custom_layer = KerasCustomLayer()
133137
self.clusterable_layer = MyClusterableLayer(10)
@@ -223,6 +227,19 @@ def testConv1DLayer(self):
223227
self.assertEqual([5, 1, 3],
224228
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
225229

230+
@keras_parameterized.run_all_keras_modes
231+
def testConv1DTransposeLayer(self):
232+
"""Verifies that we can cluster a Conv1DTranspose layer."""
233+
input_shape = (4, 28, 1)
234+
wrapped_layer = self._build_clustered_layer_model(
235+
self.keras_conv1d_tr_layer,
236+
input_shape=input_shape)
237+
238+
self._validate_clustered_layer(self.keras_conv1d_tr_layer,
239+
wrapped_layer)
240+
self.assertEqual([5, 3, 1],
241+
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
242+
226243
@keras_parameterized.run_all_keras_modes
227244
def testConv2DLayer(self):
228245
"""Verifies that we can cluster a Conv2D layer."""
@@ -236,6 +253,19 @@ def testConv2DLayer(self):
236253
self.assertEqual([4, 5, 1, 3],
237254
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
238255

256+
@keras_parameterized.run_all_keras_modes
257+
def testConv2DTransposeLayer(self):
258+
"""Verifies that we can cluster a Conv2DTranspose layer."""
259+
input_shape = (4, 28, 28, 1)
260+
wrapped_layer = self._build_clustered_layer_model(
261+
self.keras_conv2d_tr_layer,
262+
input_shape=input_shape)
263+
264+
self._validate_clustered_layer(self.keras_conv2d_tr_layer,
265+
wrapped_layer)
266+
self.assertEqual([4, 5, 3, 1],
267+
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
268+
239269
@keras_parameterized.run_all_keras_modes
240270
def testConv3DLayer(self):
241271
"""Verifies that we can cluster a Conv3D layer."""

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class ClusteringRegistry(object):
6262
# allows the wrapper to access and modify the weights.
6363
_LAYERS_WEIGHTS_MAP = {
6464
layers.Conv1D: ['kernel'],
65+
layers.Conv1DTranspose: ['kernel'],
6566
layers.Conv2D: ['kernel'],
6667
layers.Conv2DTranspose: ['kernel'],
6768
layers.Conv3D: ['kernel'],

0 commit comments

Comments
 (0)