@@ -53,6 +53,25 @@ def get_clusterable_weights(self):
53
53
class CustomNonClusterableLayer (layers .Dense ):
54
54
pass
55
55
56
+ class KerasCustomLayer (keras .layers .Layer ):
57
+ def __init__ (self , units = 32 ):
58
+ super (KerasCustomLayer , self ).__init__ ()
59
+ self .units = units
60
+
61
+ def build (self , input_shape ):
62
+ self .w = self .add_weight (
63
+ shape = (input_shape [- 1 ], self .units ),
64
+ initializer = "random_normal" ,
65
+ trainable = True ,
66
+ )
67
+ self .b = self .add_weight (
68
+ shape = (self .units ,),
69
+ initializer = "random_normal" ,
70
+ trainable = False
71
+ )
72
+
73
+ def call (self , inputs ):
74
+ return tf .matmul (inputs , self .w ) + self .b
56
75
57
76
class ClusterTest (test .TestCase , parameterized .TestCase ):
58
77
"""Unit tests for the cluster module."""
@@ -66,6 +85,7 @@ def setUp(self):
66
85
self .custom_clusterable_layer = CustomClusterableLayer (10 )
67
86
self .custom_non_clusterable_layer = CustomNonClusterableLayer (10 )
68
87
self .keras_depthwiseconv2d_layer = layers .DepthwiseConv2D ((3 , 3 ), (1 , 1 ))
88
+ self .keras_custom_layer = KerasCustomLayer ()
69
89
70
90
clustering_registry .ClusteringLookupRegistry .register_new_implementation (
71
91
{
@@ -179,6 +199,22 @@ def testClusterCustomNonClusterableLayer(self):
179
199
cluster_wrapper .ClusterWeights (custom_non_clusterable_layer ,
180
200
** self .params )
181
201
202
+ def testClusterKerasCustomLayer (self ):
203
+ """
204
+ Verifies that attempting to cluster a keras custom layer raises
205
+ an exception.
206
+ """
207
+ # If layer is not built, it has not weights, so
208
+ # we just skip it.
209
+ keras_custom_layer = self .keras_custom_layer
210
+ cluster_wrapper .ClusterWeights (keras_custom_layer ,
211
+ ** self .params )
212
+ # We need to build weights before check that clustering is not supported.
213
+ keras_custom_layer .build (input_shape = (10 , 10 ))
214
+ with self .assertRaises (ValueError ):
215
+ cluster_wrapper .ClusterWeights (keras_custom_layer ,
216
+ ** self .params )
217
+
182
218
@keras_parameterized .run_all_keras_modes
183
219
def testClusterSequentialModelSelectively (self ):
184
220
"""
0 commit comments