@@ -124,6 +124,7 @@ def setUp(self):
124124 self .custom_clusterable_layer = CustomClusterableLayer (10 )
125125 self .custom_non_clusterable_layer = CustomNonClusterableLayer (10 )
126126 self .keras_depthwiseconv2d_layer = layers .DepthwiseConv2D ((3 , 3 ), (1 , 1 ))
127+ self .clusterable_layer = MyClusterableLayer (10 )
127128 self .keras_custom_layer = KerasCustomLayer ()
128129 self .clusterable_layer = MyClusterableLayer (10 )
129130
@@ -242,6 +243,34 @@ def testClusterCustomNonClusterableLayer(self):
242243 cluster_wrapper .ClusterWeights (custom_non_clusterable_layer ,
243244 ** self .params )
244245
246+ def testClusterMyClusterableLayer (self ):
247+ # we have weights to cluster.
248+ clusterable_layer = self .clusterable_layer
249+ clusterable_layer .build (input_shape = (10 , 10 ))
250+
251+ wrapped_layer = cluster_wrapper .ClusterWeights (clusterable_layer ,
252+ ** self .params )
253+
254+ self .assertIsInstance (wrapped_layer , cluster_wrapper .ClusterWeights )
255+
256+ def testKerasCustomLayerClusterable (self ):
257+ """
258+ Verifies that we can wrap keras custom layer that is customerable.
259+ """
260+ clusterable_layer = KerasCustomLayerClusterable ()
261+ wrapped_layer = cluster_wrapper .ClusterWeights (clusterable_layer ,
262+ ** self .params )
263+
264+ self .assertIsInstance (wrapped_layer , cluster_wrapper .ClusterWeights )
265+
266+ def testClusterMyClusterableLayerInvalid (self ):
267+ """
268+ Verifies that assertion is thrown when function
269+ get_clusterable_weights is not provided.
270+ """
271+ with self .assertRaises (TypeError ):
272+ MyClusterableLayerInvalid (10 ) # pylint: disable=abstract-class-instantiated
273+
245274 def testClusterKerasCustomLayer (self ):
246275 """Verifies that attempting to cluster a keras custom layer raises an exception."""
247276 # If layer is not built, it has not weights, so
@@ -275,8 +304,7 @@ def testClusterMyClusterableLayerInvalid(self):
275304 @keras_parameterized .run_all_keras_modes
276305 def testClusterSequentialModelSelectively (self ):
277306 clustered_model = keras .Sequential ()
278- clustered_model .add (
279- cluster .cluster_weights (self .keras_clusterable_layer , ** self .params ))
307+ clustered_model .add (cluster .cluster_weights (self .keras_clusterable_layer , ** self .params ))
280308 clustered_model .add (self .keras_clusterable_layer )
281309 clustered_model .build (input_shape = (1 , 10 ))
282310
0 commit comments