Skip to content

Commit 2479fad

Browse files
committed
Added test that demonstrates that keras custom layers are not supported for clustering.
Change-Id: I235fb95986cc9219f6afb6deb95e6d7e47631476
1 parent 3a0c22d commit 2479fad

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,25 @@ def get_clusterable_weights(self):
5353
class CustomNonClusterableLayer(layers.Dense):
5454
pass
5555

56+
class KerasCustomLayer(keras.layers.Layer):
57+
def __init__(self, units=32):
58+
super(KerasCustomLayer, self).__init__()
59+
self.units = units
60+
61+
def build(self, input_shape):
62+
self.w = self.add_weight(
63+
shape=(input_shape[-1], self.units),
64+
initializer="random_normal",
65+
trainable=True,
66+
)
67+
self.b = self.add_weight(
68+
shape=(self.units,),
69+
initializer="random_normal",
70+
trainable=False
71+
)
72+
73+
def call(self, inputs):
74+
return tf.matmul(inputs, self.w) + self.b
5675

5776
class ClusterTest(test.TestCase, parameterized.TestCase):
5877
"""Unit tests for the cluster module."""
@@ -66,6 +85,7 @@ def setUp(self):
6685
self.custom_clusterable_layer = CustomClusterableLayer(10)
6786
self.custom_non_clusterable_layer = CustomNonClusterableLayer(10)
6887
self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1))
88+
self.keras_custom_layer = KerasCustomLayer()
6989

7090
clustering_registry.ClusteringLookupRegistry.register_new_implementation(
7191
{
@@ -179,6 +199,22 @@ def testClusterCustomNonClusterableLayer(self):
179199
cluster_wrapper.ClusterWeights(custom_non_clusterable_layer,
180200
**self.params)
181201

202+
def testClusterKerasCustomLayer(self):
203+
"""
204+
Verifies that attempting to cluster a keras custom layer raises
205+
an exception.
206+
"""
207+
# If layer is not built, it has not weights, so
208+
# we just skip it.
209+
keras_custom_layer = self.keras_custom_layer
210+
cluster_wrapper.ClusterWeights(keras_custom_layer,
211+
**self.params)
212+
# We need to build weights before check that clustering is not supported.
213+
keras_custom_layer.build(input_shape=(10, 10))
214+
with self.assertRaises(ValueError):
215+
cluster_wrapper.ClusterWeights(keras_custom_layer,
216+
**self.params)
217+
182218
@keras_parameterized.run_all_keras_modes
183219
def testClusterSequentialModelSelectively(self):
184220
"""

0 commit comments

Comments
 (0)