@@ -54,6 +54,25 @@ def get_clusterable_weights(self):
54
54
class CustomNonClusterableLayer (layers .Dense ):
55
55
pass
56
56
57
+ class KerasCustomLayer (keras .layers .Layer ):
58
+ def __init__ (self , units = 32 ):
59
+ super (KerasCustomLayer , self ).__init__ ()
60
+ self .units = units
61
+
62
+ def build (self , input_shape ):
63
+ self .w = self .add_weight (
64
+ shape = (input_shape [- 1 ], self .units ),
65
+ initializer = "random_normal" ,
66
+ trainable = True ,
67
+ )
68
+ self .b = self .add_weight (
69
+ shape = (self .units ,),
70
+ initializer = "random_normal" ,
71
+ trainable = False
72
+ )
73
+
74
+ def call (self , inputs ):
75
+ return tf .matmul (inputs , self .w ) + self .b
57
76
58
77
class ClusterTest (test .TestCase , parameterized .TestCase ):
59
78
"""Unit tests for the cluster module."""
@@ -67,6 +86,7 @@ def setUp(self):
67
86
self .custom_clusterable_layer = CustomClusterableLayer (10 )
68
87
self .custom_non_clusterable_layer = CustomNonClusterableLayer (10 )
69
88
self .keras_depthwiseconv2d_layer = layers .DepthwiseConv2D ((3 , 3 ), (1 , 1 ))
89
+ self .keras_custom_layer = KerasCustomLayer ()
70
90
71
91
clustering_registry .ClusteringLookupRegistry .register_new_implementation (
72
92
{
@@ -183,6 +203,22 @@ def testClusterCustomNonClusterableLayer(self):
183
203
cluster_wrapper .ClusterWeights (custom_non_clusterable_layer ,
184
204
** self .params )
185
205
206
+ def testClusterKerasCustomLayer (self ):
207
+ """
208
+ Verifies that attempting to cluster a keras custom layer raises
209
+ an exception.
210
+ """
211
+ # If layer is not built, it has not weights, so
212
+ # we just skip it.
213
+ keras_custom_layer = self .keras_custom_layer
214
+ cluster_wrapper .ClusterWeights (keras_custom_layer ,
215
+ ** self .params )
216
+ # We need to build weights before check that clustering is not supported.
217
+ keras_custom_layer .build (input_shape = (10 , 10 ))
218
+ with self .assertRaises (ValueError ):
219
+ cluster_wrapper .ClusterWeights (keras_custom_layer ,
220
+ ** self .params )
221
+
186
222
@keras_parameterized .run_all_keras_modes
187
223
def testClusterSequentialModelSelectively (self ):
188
224
"""Verifies that layers within a sequential model can be clustered selectively."""
0 commit comments