@@ -54,6 +54,25 @@ def get_clusterable_weights(self):
5454class CustomNonClusterableLayer (layers .Dense ):
5555 pass
5656
57+ class KerasCustomLayer (keras .layers .Layer ):
58+ def __init__ (self , units = 32 ):
59+ super (KerasCustomLayer , self ).__init__ ()
60+ self .units = units
61+
62+ def build (self , input_shape ):
63+ self .w = self .add_weight (
64+ shape = (input_shape [- 1 ], self .units ),
65+ initializer = "random_normal" ,
66+ trainable = True ,
67+ )
68+ self .b = self .add_weight (
69+ shape = (self .units ,),
70+ initializer = "random_normal" ,
71+ trainable = False
72+ )
73+
74+ def call (self , inputs ):
75+ return tf .matmul (inputs , self .w ) + self .b
5776
5877class ClusterTest (test .TestCase , parameterized .TestCase ):
5978 """Unit tests for the cluster module."""
@@ -67,6 +86,7 @@ def setUp(self):
6786 self .custom_clusterable_layer = CustomClusterableLayer (10 )
6887 self .custom_non_clusterable_layer = CustomNonClusterableLayer (10 )
6988 self .keras_depthwiseconv2d_layer = layers .DepthwiseConv2D ((3 , 3 ), (1 , 1 ))
89+ self .keras_custom_layer = KerasCustomLayer ()
7090
7191 clustering_registry .ClusteringLookupRegistry .register_new_implementation (
7292 {
@@ -183,6 +203,22 @@ def testClusterCustomNonClusterableLayer(self):
183203 cluster_wrapper .ClusterWeights (custom_non_clusterable_layer ,
184204 ** self .params )
185205
206+ def testClusterKerasCustomLayer (self ):
207+ """
208+ Verifies that attempting to cluster a keras custom layer raises
209+ an exception.
210+ """
211+ # If layer is not built, it has not weights, so
212+ # we just skip it.
213+ keras_custom_layer = self .keras_custom_layer
214+ cluster_wrapper .ClusterWeights (keras_custom_layer ,
215+ ** self .params )
216+ # We need to build weights before check that clustering is not supported.
217+ keras_custom_layer .build (input_shape = (10 , 10 ))
218+ with self .assertRaises (ValueError ):
219+ cluster_wrapper .ClusterWeights (keras_custom_layer ,
220+ ** self .params )
221+
186222 @keras_parameterized .run_all_keras_modes
187223 def testClusterSequentialModelSelectively (self ):
188224 """Verifies that layers within a sequential model can be clustered selectively."""
0 commit comments