@@ -53,6 +53,25 @@ def get_clusterable_weights(self):
5353class CustomNonClusterableLayer (layers .Dense ):
5454 pass
5555
56+ class KerasCustomLayer (keras .layers .Layer ):
57+ def __init__ (self , units = 32 ):
58+ super (KerasCustomLayer , self ).__init__ ()
59+ self .units = units
60+
61+ def build (self , input_shape ):
62+ self .w = self .add_weight (
63+ shape = (input_shape [- 1 ], self .units ),
64+ initializer = "random_normal" ,
65+ trainable = True ,
66+ )
67+ self .b = self .add_weight (
68+ shape = (self .units ,),
69+ initializer = "random_normal" ,
70+ trainable = False
71+ )
72+
73+ def call (self , inputs ):
74+ return tf .matmul (inputs , self .w ) + self .b
5675
5776class ClusterTest (test .TestCase , parameterized .TestCase ):
5877 """Unit tests for the cluster module."""
@@ -66,6 +85,7 @@ def setUp(self):
6685 self .custom_clusterable_layer = CustomClusterableLayer (10 )
6786 self .custom_non_clusterable_layer = CustomNonClusterableLayer (10 )
6887 self .keras_depthwiseconv2d_layer = layers .DepthwiseConv2D ((3 , 3 ), (1 , 1 ))
88+ self .keras_custom_layer = KerasCustomLayer ()
6989
7090 clustering_registry .ClusteringLookupRegistry .register_new_implementation (
7191 {
@@ -179,6 +199,22 @@ def testClusterCustomNonClusterableLayer(self):
179199 cluster_wrapper .ClusterWeights (custom_non_clusterable_layer ,
180200 ** self .params )
181201
202+ def testClusterKerasCustomLayer (self ):
203+ """
204+ Verifies that attempting to cluster a keras custom layer raises
205+ an exception.
206+ """
207+ # If layer is not built, it has not weights, so
208+ # we just skip it.
209+ keras_custom_layer = self .keras_custom_layer
210+ cluster_wrapper .ClusterWeights (keras_custom_layer ,
211+ ** self .params )
212+ # We need to build weights before check that clustering is not supported.
213+ keras_custom_layer .build (input_shape = (10 , 10 ))
214+ with self .assertRaises (ValueError ):
215+ cluster_wrapper .ClusterWeights (keras_custom_layer ,
216+ ** self .params )
217+
182218 @keras_parameterized .run_all_keras_modes
183219 def testClusterSequentialModelSelectively (self ):
184220 """
0 commit comments