Skip to content

Commit 9193d70

Browse files
Merge pull request #596 from wwwind:test_keras_layer
PiperOrigin-RevId: 365525089
2 parents c27c9ca + 2479fad commit 9193d70

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,25 @@ def get_clusterable_weights(self):
5454
class CustomNonClusterableLayer(layers.Dense):
5555
pass
5656

57+
class KerasCustomLayer(keras.layers.Layer):
58+
def __init__(self, units=32):
59+
super(KerasCustomLayer, self).__init__()
60+
self.units = units
61+
62+
def build(self, input_shape):
63+
self.w = self.add_weight(
64+
shape=(input_shape[-1], self.units),
65+
initializer="random_normal",
66+
trainable=True,
67+
)
68+
self.b = self.add_weight(
69+
shape=(self.units,),
70+
initializer="random_normal",
71+
trainable=False
72+
)
73+
74+
def call(self, inputs):
75+
return tf.matmul(inputs, self.w) + self.b
5776

5877
class ClusterTest(test.TestCase, parameterized.TestCase):
5978
"""Unit tests for the cluster module."""
@@ -67,6 +86,7 @@ def setUp(self):
6786
self.custom_clusterable_layer = CustomClusterableLayer(10)
6887
self.custom_non_clusterable_layer = CustomNonClusterableLayer(10)
6988
self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1))
89+
self.keras_custom_layer = KerasCustomLayer()
7090

7191
clustering_registry.ClusteringLookupRegistry.register_new_implementation(
7292
{
@@ -183,6 +203,22 @@ def testClusterCustomNonClusterableLayer(self):
183203
cluster_wrapper.ClusterWeights(custom_non_clusterable_layer,
184204
**self.params)
185205

206+
def testClusterKerasCustomLayer(self):
207+
"""
208+
Verifies that attempting to cluster a keras custom layer raises
209+
an exception.
210+
"""
211+
# If layer is not built, it has not weights, so
212+
# we just skip it.
213+
keras_custom_layer = self.keras_custom_layer
214+
cluster_wrapper.ClusterWeights(keras_custom_layer,
215+
**self.params)
216+
# We need to build weights before check that clustering is not supported.
217+
keras_custom_layer.build(input_shape=(10, 10))
218+
with self.assertRaises(ValueError):
219+
cluster_wrapper.ClusterWeights(keras_custom_layer,
220+
**self.params)
221+
186222
@keras_parameterized.run_all_keras_modes
187223
def testClusterSequentialModelSelectively(self):
188224
"""Verifies that layers within a sequential model can be clustered selectively."""

0 commit comments

Comments
 (0)