Skip to content

Commit 16cb032

Browse files
committed
Added support for Conv1DTranspose + tests.
Change-Id: I0e31a73869dd2e2e6fa374dc8471f7ea1d88f605
1 parent a20982e commit 16cb032

File tree

3 files changed

+107
-0
lines changed

3 files changed

+107
-0
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ def dataset_generator2(self):
8585
for x, y in zip(self.x_train2, self.y_train2):
8686
yield np.array([x]), np.array([y])
8787

88+
def _batch(self, dims, batch_size):
89+
if dims[0] is None:
90+
dims[0] = batch_size
91+
return dims
92+
8893
def end_to_end_testing(self, original_model, clusters_check=None):
8994
"""Test End to End clustering."""
9095

@@ -225,6 +230,79 @@ def testSparsityIsPreservedDuringTraining(self):
225230
nr_of_unique_weights_after,
226231
clustering_params["number_of_clusters"])
227232

233+
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
234+
def testEndToEndSequential(self):
235+
"""Test End to End clustering - sequential model."""
236+
original_model = keras.Sequential([
237+
layers.Dense(5, input_shape=(5,)),
238+
layers.Dense(5),
239+
])
240+
241+
def clusters_check(stripped_model):
242+
# dense layer
243+
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
244+
unique_weights = set(weights_as_list)
245+
self.assertLessEqual(
246+
len(unique_weights), self.params["number_of_clusters"])
247+
248+
self.end_to_end_testing(original_model, clusters_check)
249+
250+
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
251+
def testEndToEndConv1DAndConv1DTranspose(self):
252+
"""Test End to End clustering - model with Conv1D and Conv1DTranspose."""
253+
inp = layers.Input(batch_shape=(1, 16))
254+
x = layers.Conv1D(10, 16, 4, padding = "valid", use_bias=False)(tf.expand_dims(inp, axis=-1))
255+
y = layers.Conv1DTranspose(1, 16, 4, padding = "valid", use_bias=False)(x)
256+
model = keras.models.Model(inputs=inp, outputs=[y])
257+
258+
def apply_clustering(layer):
259+
if isinstance(layer, keras.layers.Conv1D) or \
260+
isinstance(layer, keras.layers.Conv1DTranspose):
261+
return cluster.cluster_weights(layer, **self.params)
262+
return layer
263+
264+
model_to_cluster = keras.models.clone_model(
265+
model,
266+
clone_function=apply_clustering,
267+
)
268+
269+
model_to_cluster.compile(
270+
loss=keras.losses.categorical_crossentropy,
271+
optimizer="adam",
272+
metrics=["accuracy"]
273+
)
274+
model_to_cluster.fit(
275+
np.random.randn(*self._batch(model.input.get_shape().as_list(), 16)),
276+
np.random.randn(*self._batch(model.output.get_shape().as_list(), 16)),
277+
steps_per_epoch=1)
278+
clustered_model = cluster.strip_clustering(model_to_cluster)
279+
280+
def do_checks(layer, layer_name):
281+
self.assertEqual(layer.name, layer_name)
282+
unique_weights = np.unique(layer.weights[0].numpy().flatten())
283+
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
284+
285+
do_checks(clustered_model.layers[2], 'conv1d')
286+
do_checks(clustered_model.layers[3], 'conv1d_transpose')
287+
288+
def testStripClusteringSequentialModelWithRegulariser(self):
289+
"""
290+
Verifies that stripping the clustering wrappers from a sequential model
291+
produces the expected config.
292+
"""
293+
original_model = keras.Sequential([
294+
layers.Dense(5, input_shape=(5,)),
295+
layers.Dense(5, kernel_regularizer=tf.keras.regularizers.L1(0.01)),
296+
])
297+
298+
def clusters_check(stripped_model):
299+
# dense layer
300+
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
301+
unique_weights = set(weights_as_list)
302+
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
303+
304+
self.end_to_end_testing(original_model, clusters_check)
305+
228306
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
229307
def testEndToEndFunctional(self):
230308
"""Test End to End clustering - functional model."""

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@ def setUp(self):
127127
self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1))
128128
self.keras_dense_layer = layers.Dense(10)
129129
self.keras_conv1d_layer = layers.Conv1D(filters=3, kernel_size=(5))
130+
self.keras_conv1d_tr_layer = layers.Conv1DTranspose(filters=3, kernel_size=(5))
130131
self.keras_conv2d_layer = layers.Conv2D(filters=3, kernel_size=(4, 5))
132+
self.keras_conv2d_tr_layer = layers.Conv2DTranspose(filters=3, kernel_size=(4, 5))
131133
self.keras_conv3d_layer = layers.Conv3D(filters=2, kernel_size=(3, 4, 5))
132134
self.keras_custom_layer = KerasCustomLayer()
133135
self.clusterable_layer = MyClusterableLayer(10)
@@ -223,6 +225,19 @@ def testConv1DLayer(self):
223225
self.assertEqual([5, 1, 3],
224226
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
225227

228+
@keras_parameterized.run_all_keras_modes
229+
def testConv1DTransposeLayer(self):
230+
"""Verifies that we can cluster a Conv1DTranspose layer."""
231+
input_shape = (4, 28, 1)
232+
wrapped_layer = self._build_clustered_layer_model(
233+
self.keras_conv1d_tr_layer,
234+
input_shape=input_shape)
235+
236+
self._validate_clustered_layer(self.keras_conv1d_tr_layer,
237+
wrapped_layer)
238+
self.assertEqual([5, 3, 1],
239+
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
240+
226241
@keras_parameterized.run_all_keras_modes
227242
def testConv2DLayer(self):
228243
"""Verifies that we can cluster a Conv2D layer."""
@@ -236,6 +251,19 @@ def testConv2DLayer(self):
236251
self.assertEqual([4, 5, 1, 3],
237252
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
238253

254+
@keras_parameterized.run_all_keras_modes
255+
def testConv2DTransposeLayer(self):
256+
"""Verifies that we can cluster a Conv2DTranspose layer."""
257+
input_shape = (4, 28, 28, 1)
258+
wrapped_layer = self._build_clustered_layer_model(
259+
self.keras_conv2d_tr_layer,
260+
input_shape=input_shape)
261+
262+
self._validate_clustered_layer(self.keras_conv2d_tr_layer,
263+
wrapped_layer)
264+
self.assertEqual([4, 5, 3, 1],
265+
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
266+
239267
@keras_parameterized.run_all_keras_modes
240268
def testConv3DLayer(self):
241269
"""Verifies that we can cluster a Conv3D layer."""

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class ClusteringRegistry(object):
6262
# allows the wrapper to access and modify the weights.
6363
_LAYERS_WEIGHTS_MAP = {
6464
layers.Conv1D: ['kernel'],
65+
layers.Conv1DTranspose: ['kernel'],
6566
layers.Conv2D: ['kernel'],
6667
layers.Conv2DTranspose: ['kernel'],
6768
layers.Conv3D: ['kernel'],

0 commit comments

Comments
 (0)