Skip to content

Commit b7d7bce

Browse files
Merge pull request #616 from wwwind:clusterable_layer
PiperOrigin-RevId: 369787364
2 parents 6f072d9 + 2f9d874 commit b7d7bce

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def setUp(self):
124124
self.custom_clusterable_layer = CustomClusterableLayer(10)
125125
self.custom_non_clusterable_layer = CustomNonClusterableLayer(10)
126126
self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1))
127+
self.clusterable_layer = MyClusterableLayer(10)
127128
self.keras_custom_layer = KerasCustomLayer()
128129
self.clusterable_layer = MyClusterableLayer(10)
129130

@@ -242,6 +243,34 @@ def testClusterCustomNonClusterableLayer(self):
242243
cluster_wrapper.ClusterWeights(custom_non_clusterable_layer,
243244
**self.params)
244245

246+
def testClusterMyClusterableLayer(self):
247+
# we have weights to cluster.
248+
clusterable_layer = self.clusterable_layer
249+
clusterable_layer.build(input_shape=(10, 10))
250+
251+
wrapped_layer = cluster_wrapper.ClusterWeights(clusterable_layer,
252+
**self.params)
253+
254+
self.assertIsInstance(wrapped_layer, cluster_wrapper.ClusterWeights)
255+
256+
def testKerasCustomLayerClusterable(self):
257+
"""
258+
Verifies that we can wrap keras custom layer that is customerable.
259+
"""
260+
clusterable_layer = KerasCustomLayerClusterable()
261+
wrapped_layer = cluster_wrapper.ClusterWeights(clusterable_layer,
262+
**self.params)
263+
264+
self.assertIsInstance(wrapped_layer, cluster_wrapper.ClusterWeights)
265+
266+
def testClusterMyClusterableLayerInvalid(self):
267+
"""
268+
Verifies that assertion is thrown when function
269+
get_clusterable_weights is not provided.
270+
"""
271+
with self.assertRaises(TypeError):
272+
MyClusterableLayerInvalid(10) # pylint: disable=abstract-class-instantiated
273+
245274
def testClusterKerasCustomLayer(self):
246275
"""Verifies that attempting to cluster a keras custom layer raises an exception."""
247276
# If layer is not built, it has not weights, so
@@ -275,8 +304,7 @@ def testClusterMyClusterableLayerInvalid(self):
275304
@keras_parameterized.run_all_keras_modes
276305
def testClusterSequentialModelSelectively(self):
277306
clustered_model = keras.Sequential()
278-
clustered_model.add(
279-
cluster.cluster_weights(self.keras_clusterable_layer, **self.params))
307+
clustered_model.add(cluster.cluster_weights(self.keras_clusterable_layer, **self.params))
280308
clustered_model.add(self.keras_clusterable_layer)
281309
clustered_model.build(input_shape=(1, 10))
282310

tensorflow_model_optimization/python/core/clustering/keras/clustering_algorithm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
@six.add_metaclass(abc.ABCMeta)
2323
class AbstractClusteringAlgorithm(object):
24-
"""Abstrac class to implement highly efficient vectorised look-ups.
24+
"""Abstract class to implement highly efficient vectorised look-ups.
2525
2626
We do not utilise looping for that purpose, instead we `smartly` reshape and
2727
tile arrays. The trade-off is that we are potentially using way more memory

0 commit comments

Comments
 (0)