Skip to content

Commit 1155eee

Browse files
committed
Clusterable layer API.
Change-Id: I9bb372717923348d4c42fdecf7eeea64ce44f49b
1 parent e4a5200 commit 1155eee

File tree

7 files changed

+451
-27
lines changed

7 files changed

+451
-27
lines changed

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,14 @@ py_test(
164164
"//tensorflow_model_optimization/python/core/keras:test_utils",
165165
],
166166
)
167+
168+
py_test(
169+
name = "mnist_customerable_test",
170+
srcs = ["mnist_customerable_test.py"],
171+
python_version = "PY3",
172+
visibility = ["//visibility:public"],
173+
deps = [
174+
":cluster"
175+
# tensorflow dep1,
176+
],
177+
)

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,46 @@ class CustomNonClusterableLayer(layers.Dense):
5555
pass
5656

5757

58+
class MyCustomerableLayer(keras.layers.Dense,
59+
clusterable_layer.ClusterableLayer):
60+
61+
def __init__(self, num_units):
62+
super().__init__(num_units)
63+
64+
def get_clusterable_weights(self):
65+
# Cluster kernel and bias.
66+
return [('kernel', self.kernel), ('bias', self.bias)]
67+
68+
class MyCustomerableLayerInvalid(keras.layers.Dense,
69+
clusterable_layer.ClusterableLayer):
70+
""" This layer is invalid, because it does not provide
71+
get_clusterable_weights function.
72+
"""
73+
def __init__(self, num_units):
74+
super().__init__(num_units)
75+
76+
class TestCustomerableWeightsCA(clustering_registry.AbstractClusteringAlgorithm):
77+
""" Dummy class derived from AbstractClusteringAlgorithm."""
78+
def get_pulling_indices(self, weight):
79+
return [1, 2, 3]
80+
81+
class KerasCustomLayerClusterable(keras.layers.Layer,
82+
clusterable_layer.ClusterableLayer):
83+
""" This keras custom layer is derived from ClusterableLayer
84+
and it provides own implementation of the clustering
85+
algorithm.
86+
"""
87+
88+
def __init__(self):
89+
super().__init__()
90+
self.kernel = None
91+
92+
def get_clusterable_weights(self):
93+
return [('kernel', self.kernel)]
94+
95+
def get_clusterable_algorithm(self, weight_name):
96+
return TestCustomerableWeightsCA
97+
5898
class ClusterTest(test.TestCase, parameterized.TestCase):
5999
"""Unit tests for the cluster module."""
60100

@@ -67,6 +107,8 @@ def setUp(self):
67107
self.custom_clusterable_layer = CustomClusterableLayer(10)
68108
self.custom_non_clusterable_layer = CustomNonClusterableLayer(10)
69109
self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1))
110+
self.customerable_layer = MyCustomerableLayer(10)
111+
self.keras_custom_layer = KerasCustomLayer()
70112

71113
clustering_registry.ClusteringLookupRegistry.register_new_implementation(
72114
{
@@ -183,12 +225,54 @@ def testClusterCustomNonClusterableLayer(self):
183225
cluster_wrapper.ClusterWeights(custom_non_clusterable_layer,
184226
**self.params)
185227

228+
def testClusterMyCustomerableLayer(self):
229+
# we have weights to cluster.
230+
customerable_layer = self.customerable_layer
231+
customerable_layer.build(input_shape=(10, 10))
232+
233+
wrapped_layer = cluster_wrapper.ClusterWeights(customerable_layer,
234+
**self.params)
235+
236+
self.assertIsInstance(wrapped_layer, cluster_wrapper.ClusterWeights)
237+
238+
def testKerasCustomLayerClusterable(self):
239+
"""
240+
Verifies that we can wrap keras custom layer that is customerable.
241+
"""
242+
customerable_layer = KerasCustomLayerClusterable()
243+
wrapped_layer = cluster_wrapper.ClusterWeights(customerable_layer,
244+
**self.params)
245+
246+
self.assertIsInstance(wrapped_layer, cluster_wrapper.ClusterWeights)
247+
248+
def testClusterMyCustomerableLayerInvalid(self):
249+
"""
250+
Verifies that assertion is thrown when function
251+
get_clusterable_weights is not provided.
252+
"""
253+
with self.assertRaises(TypeError):
254+
MyCustomerableLayerInvalid(10) # pylint: disable=abstract-class-instantiated
255+
256+
def testClusterKerasCustomLayer(self):
257+
"""
258+
Verifies that attempting to cluster a keras custom layer raises
259+
an exception.
260+
"""
261+
# If layer is not built, it has not weights, so
262+
# we just skip it.
263+
keras_custom_layer = self.keras_custom_layer
264+
cluster_wrapper.ClusterWeights(keras_custom_layer,
265+
**self.params)
266+
# We need to build weights before check that clustering is not supported.
267+
keras_custom_layer.build(input_shape=(10, 10))
268+
with self.assertRaises(ValueError):
269+
cluster_wrapper.ClusterWeights(keras_custom_layer,
270+
**self.params)
271+
272+
>>>>>>> 8fe29ec... MLTOOLS-1031 Customerable layer API.
186273
@keras_parameterized.run_all_keras_modes
187274
def testClusterSequentialModelSelectively(self):
188-
"""Verifies that layers within a sequential model can be clustered selectively."""
189275
clustered_model = keras.Sequential()
190-
clustered_model.add(
191-
cluster.cluster_weights(self.keras_clusterable_layer, **self.params))
192276
clustered_model.add(self.keras_clusterable_layer)
193277
clustered_model.build(input_shape=(1, 10))
194278

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,10 @@ def build(self, input_shape):
178178
# variable is either in the self._trainable_weights or in
179179
# self._non_trainable_weights and self.weights is the result of
180180
# concatenation of those arrays
181-
original_index = self.layer.weights.index(weight)
181+
original_index = 0
182+
for i in range(len(self.layer.weights)):
183+
if self.layer.weights[i].name == weight.name:
184+
original_index = i
182185
self.gone_variables.append(original_index)
183186

184187
# Again, not sure if this is needed. Leaving for now.

tensorflow_model_optimization/python/core/clustering/keras/clusterable_layer.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@
2222
class ClusterableLayer:
2323
"""Abstract Base Class for making your own keras layer clusterable.
2424
25-
Custom keras layers that need to support clustering should implement this
26-
class.
25+
Your layer could be derived from a keras built-in layer or
26+
it could be a keras custom layer.
27+
28+
The function get_clusterable_weights should be provided in both cases.
29+
30+
The function get_clusterable_algorithm is provided, when weights for
31+
clustering is added in the keras layer.
2732
2833
"""
2934

@@ -40,3 +45,21 @@ def get_clusterable_weights(self):
4045
kernel object itself.
4146
"""
4247
raise NotImplementedError('Must be implemented in subclasses.')
48+
49+
def get_clusterable_algorithm(self, weight_name):
50+
"""Returns class with the clustering algorithm for the given weight_name.
51+
52+
This function needs to be implemented for the customerable layers.
53+
If the layer is derived from the built-in keras layer, the clustering
54+
algorithm for the base built-in keras layer is used.
55+
56+
The returned class should be derived from AbstractClusteringAlgorithm and
57+
implements the function get_pulling_indices.
58+
This function is used to provide a special lookup function for the custom weights.
59+
It reshapes and tile centroids the same way as the weights. This allows us
60+
to find pulling indices efficiently.
61+
62+
Args:
63+
weight_name ([string]): The name of the weight variable.
64+
"""
65+
return None

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,6 @@ def get_pulling_indices(self, weight):
170170

171171
return pulling_indices
172172

173-
174173
class ClusteringLookupRegistry(object):
175174
"""
176175
The keys represent built-in keras layers and the values represent the
@@ -179,17 +178,17 @@ class ClusteringLookupRegistry(object):
179178
work on, or the strategy is not currently supported
180179
"""
181180
_LAYERS_RESHAPE_MAP = {
182-
layers.Conv1D: {'kernel': ConvolutionalWeightsCA},
183-
layers.Conv2D: {'kernel': ConvolutionalWeightsCA},
184-
layers.Conv2DTranspose: {'kernel': ConvolutionalWeightsCA},
185-
layers.Conv3D: {'kernel': ConvolutionalWeightsCA},
186-
layers.Conv3DTranspose: {'kernel': ConvolutionalWeightsCA},
187-
layers.SeparableConv1D: {'pointwise_kernel': ConvolutionalWeightsCA},
188-
layers.SeparableConv2D: {'pointwise_kernel': ConvolutionalWeightsCA},
189-
layers.Dense: {'kernel': DenseWeightsCA},
190-
layers.Embedding: {'embeddings': DenseWeightsCA},
191-
layers.LocallyConnected1D: {'kernel': ConvolutionalWeightsCA},
192-
layers.LocallyConnected2D: {'kernel': ConvolutionalWeightsCA},
181+
layers.Conv1D: {'kernel': ConvolutionalWeightsCA, 'bias': BiasWeightsCA},
182+
layers.Conv2D: {'kernel': ConvolutionalWeightsCA, 'bias': BiasWeightsCA},
183+
layers.Conv2DTranspose: {'kernel': ConvolutionalWeightsCA, 'bias': BiasWeightsCA},
184+
layers.Conv3D: {'kernel': ConvolutionalWeightsCA, 'bias': BiasWeightsCA},
185+
layers.Conv3DTranspose: {'kernel': ConvolutionalWeightsCA, 'bias': BiasWeightsCA},
186+
layers.SeparableConv1D: {'pointwise_kernel': ConvolutionalWeightsCA, 'bias': BiasWeightsCA},
187+
layers.SeparableConv2D: {'pointwise_kernel': ConvolutionalWeightsCA, 'bias': BiasWeightsCA},
188+
layers.Dense: {'kernel': DenseWeightsCA, 'bias': BiasWeightsCA},
189+
layers.Embedding: {'embeddings': DenseWeightsCA, 'bias': BiasWeightsCA},
190+
layers.LocallyConnected1D: {'kernel': ConvolutionalWeightsCA, 'bias': BiasWeightsCA},
191+
layers.LocallyConnected2D: {'kernel': ConvolutionalWeightsCA, 'bias': BiasWeightsCA},
193192
}
194193

195194
@classmethod
@@ -229,15 +228,37 @@ def get_clustering_impl(cls, layer, weight_name):
229228
:param weight_name: concrete weight name to be clustered.
230229
:return: a concrete implementation of a lookup algorithm
231230
"""
231+
custom_layer_of_built_layer = None
232232
if not layer.__class__ in cls._LAYERS_RESHAPE_MAP:
233-
raise ValueError(
234-
"Class {given_class} has not been registerd in the"
235-
"ClusteringLookupRegistry. Use ClusteringLookupRegistry."
236-
"register_new_implemenetation to fix this.".format(
237-
given_class=layer.__class__
238-
)
239-
)
240-
if weight_name not in cls._LAYERS_RESHAPE_MAP[layer.__class__]:
233+
# Checks whether we have a customerable layer derived from built-in keras class.
234+
for key in cls._LAYERS_RESHAPE_MAP:
235+
if issubclass(layer.__class__, key):
236+
custom_layer_of_built_layer = key
237+
if not custom_layer_of_built_layer:
238+
# Checks whether we have a customerable layer that provides
239+
# clusterable algorithm for the given weights.
240+
if issubclass(layer.__class__, clusterable_layer.ClusterableLayer) and \
241+
layer.get_clusterable_algorithm is not None:
242+
ans = layer.get_clusterable_algorithm(weight_name)
243+
if not ans:
244+
raise ValueError(
245+
"Class {given_class} does not provided clustering algorithm"
246+
"for the weights with the name {weight_name}.".format(
247+
given_class=layer.__class__, weight_name=weight_name
248+
)
249+
)
250+
else:
251+
return ans
252+
else:
253+
raise ValueError(
254+
"Class {given_class} has not derived from ClusterableLayer"
255+
"or the funtion get_pulling_indices is not provided.".format(
256+
given_class=layer.__class__
257+
)
258+
)
259+
else:
260+
custom_layer_of_built_layer = layer.__class__
261+
if weight_name not in cls._LAYERS_RESHAPE_MAP[custom_layer_of_built_layer]:
241262
raise ValueError(
242263
"Weight with the name '{given_weight_name}' for class {given_class} "
243264
"has not been registerd in the ClusteringLookupRegistry. Use "
@@ -249,7 +270,7 @@ def get_clustering_impl(cls, layer, weight_name):
249270
)
250271
# Different weights will have different shapes hence there is double hash
251272
# map lookup.
252-
return cls._LAYERS_RESHAPE_MAP[layer.__class__][weight_name]
273+
return cls._LAYERS_RESHAPE_MAP[custom_layer_of_built_layer][weight_name]
253274

254275

255276
class ClusteringRegistry(object):

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,42 @@ def testGetClusteringImplFailsWithKnonwClassUnknownWeight(self):
201201
ClusteringLookupRegistry.get_clustering_impl(layers.Dense(10),
202202
'no_such_weight')
203203

204+
class KerasCustomLayerClusterableInvalid(keras.layers.Layer,
205+
clusterable_layer.ClusterableLayer):
206+
""" This keras custom layer is derived from ClusterableLayer
207+
and it provides own implementation of the clustering
208+
algorithm.
209+
"""
210+
211+
def __init__(self, units=10):
212+
super(KerasCustomLayerClusterableInvalid, self).__init__()
213+
self.units = units
214+
215+
def build(self, input_shape):
216+
self.w = self.add_weight(
217+
shape=(input_shape[-1], self.units),
218+
initializer="random_normal",
219+
trainable=True,
220+
)
221+
222+
def get_clusterable_weights(self):
223+
return [('w', self.w)]
224+
225+
def testKerasCustomLayerClusterableInvalid(self):
226+
"""
227+
Verifies that get_clustering_impl() raises an error when invoked with a
228+
keras custom layer derived from ClusterableLayer, but the function
229+
get_clustering_algorithm is not provided.
230+
"""
231+
with self.assertRaises(ValueError):
232+
ClusteringLookupRegistry.get_clustering_impl(
233+
KerasCustomLayerClusterableInvalid(), 'w')
234+
204235
@parameterized.parameters(
205236
(layers.Conv2D, 'kernel', clustering_registry.ConvolutionalWeightsCA),
206237
(layers.Conv1D, 'kernel', clustering_registry.ConvolutionalWeightsCA),
238+
(layers.Conv2D, 'bias', clustering_registry.BiasWeightsCA),
239+
(layers.Conv1D, 'bias', clustering_registry.BiasWeightsCA),
207240
)
208241
def testReturnsResultsForKnownTypeKnownWeights(self,
209242
layer_type,

0 commit comments

Comments
 (0)