Skip to content

Commit 69376e4

Browse files
authored
Merge branch 'master' into clusterable_layer
2 parents a780e56 + 92486bf commit 69376e4

File tree

23 files changed

+610
-341
lines changed

23 files changed

+610
-341
lines changed

tensorflow_model_optimization/python/core/api/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ py_library(
1313
"quantization/__init__.py",
1414
"quantization/keras/__init__.py",
1515
"quantization/keras/default_8bit/__init__.py",
16+
"quantization/keras/default_8bit/default_8bit_transforms/__init__.py",
17+
"quantization/keras/graph_transformations/__init__.py",
18+
"quantization/keras/graph_transformations/model_transformer/__init__.py",
19+
"quantization/keras/graph_transformations/transforms/__init__.py",
1620
"quantization/keras/quantizers/__init__.py",
1721
"sparsity/__init__.py",
1822
"sparsity/keras/__init__.py",

tensorflow_model_optimization/python/core/api/quantization/keras/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# submodules
1919
from tensorflow_model_optimization.python.core.api.quantization.keras import quantizers
2020
from tensorflow_model_optimization.python.core.api.quantization.keras import default_8bit
21+
from tensorflow_model_optimization.python.core.api.quantization.keras import graph_transformations
2122

2223
# quantize all layers with default quantization implementation.
2324
from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_model
@@ -30,6 +31,7 @@
3031
# quantize with custom quantization parameterization or implementation, or
3132
# handle custom Keras layers.
3233
from tensorflow_model_optimization.python.core.quantization.keras.quantize_config import QuantizeConfig
34+
from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
3335

3436
# Deserialize quantized model for Keras h5 format.
3537
from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_scope

tensorflow_model_optimization/python/core/api/quantization/keras/default_8bit/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
"""Module containing 8bit default quantization scheme."""
1616
# pylint: disable=g-bad-import-order
1717

18+
# submodules
19+
from tensorflow_model_optimization.python.core.api.quantization.keras.default_8bit import default_8bit_transforms
20+
1821
# The 8bit default quantization scheme classes.
1922
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_scheme import Default8BitQuantizeScheme
2023
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_layout_transform import Default8BitQuantizeLayoutTransform
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Module containing 8bit default transforms."""
16+
17+
# The 8bit default transform classes.
18+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import ConcatTransform
19+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import ConcatTransform3Inputs
20+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import ConcatTransform4Inputs
21+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import ConcatTransform5Inputs
22+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import ConcatTransform6Inputs
23+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import Conv2DBatchNormActivationQuantize
24+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import Conv2DBatchNormQuantize
25+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import Conv2DBatchNormReLUQuantize
26+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import Conv2DReshapeBatchNormActivationQuantize
27+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import Conv2DReshapeBatchNormQuantize
28+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import Conv2DReshapeBatchNormReLUQuantize
29+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import InputLayerQuantize
30+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import LayerReluActivationQuantize
31+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import LayerReLUQuantize
32+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import SeparableConv1DQuantize
33+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_transforms import SeparableConvQuantize
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Module containing code for graph transformation."""
16+
17+
# submodules
18+
from tensorflow_model_optimization.python.core.api.quantization.keras.graph_transformations import model_transformer
19+
from tensorflow_model_optimization.python.core.api.quantization.keras.graph_transformations import transforms
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Module containing classes for model transformer."""
16+
17+
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations.model_transformer import ModelTransformer
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Module containing classes for transform."""
16+
17+
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations.transforms import LayerNode
18+
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations.transforms import LayerPattern
19+
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations.transforms import Transform

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ py_library(
4444
visibility = ["//visibility:public"],
4545
deps = [
4646
":clusterable_layer",
47-
":clustering_algorithm"
47+
":clustering_algorithm",
4848
],
4949
)
5050

@@ -183,7 +183,7 @@ py_test(
183183
python_version = "PY3",
184184
visibility = ["//visibility:public"],
185185
deps = [
186-
":cluster"
186+
":cluster",
187187
# tensorflow dep1,
188188
],
189189
)
@@ -194,7 +194,7 @@ py_test(
194194
python_version = "PY3",
195195
visibility = ["//visibility:public"],
196196
deps = [
197-
":cluster"
197+
":cluster",
198198
# tensorflow dep1,
199199
],
200-
)
200+
)

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -54,55 +54,52 @@ def get_clusterable_weights(self):
5454
class CustomNonClusterableLayer(layers.Dense):
5555
pass
5656

57+
5758
class KerasCustomLayer(keras.layers.Layer):
59+
5860
def __init__(self, units=32):
5961
super(KerasCustomLayer, self).__init__()
6062
self.units = units
6163

6264
def build(self, input_shape):
6365
self.w = self.add_weight(
64-
shape=(input_shape[-1], self.units),
65-
initializer="random_normal",
66-
trainable=True,
66+
shape=(input_shape[-1], self.units),
67+
initializer='random_normal',
68+
trainable=True,
6769
)
6870
self.b = self.add_weight(
69-
shape=(self.units,),
70-
initializer="random_normal",
71-
trainable=False
72-
)
71+
shape=(self.units,), initializer='random_normal', trainable=False)
7372

7473
def call(self, inputs):
7574
return tf.matmul(inputs, self.w) + self.b
7675

77-
class MyClusterableLayer(keras.layers.Dense,
78-
clusterable_layer.ClusterableLayer):
7976

80-
def __init__(self, num_units):
81-
super().__init__(num_units)
77+
class MyClusterableLayer(keras.layers.Dense,
78+
clusterable_layer.ClusterableLayer):
8279

8380
def get_clusterable_weights(self):
8481
# Cluster kernel and bias.
8582
return [('kernel', self.kernel), ('bias', self.bias)]
8683

84+
8785
class MyClusterableLayerInvalid(keras.layers.Dense,
88-
clusterable_layer.ClusterableLayer):
89-
""" This layer is invalid, because it does not provide
90-
get_clusterable_weights function.
86+
clusterable_layer.ClusterableLayer):
87+
"""This layer is invalid: it does not implement get_clusterable_weights(self).
9188
"""
92-
def __init__(self, num_units):
93-
super().__init__(num_units)
89+
pass
90+
91+
92+
class TestCustomerableWeightsCA(clustering_registry.AbstractClusteringAlgorithm
93+
):
94+
"""Dummy class derived from AbstractClusteringAlgorithm."""
9495

95-
class TestCustomerableWeightsCA(clustering_registry.AbstractClusteringAlgorithm):
96-
""" Dummy class derived from AbstractClusteringAlgorithm."""
9796
def get_pulling_indices(self, weight):
9897
return [1, 2, 3]
9998

99+
100100
class KerasCustomLayerClusterable(keras.layers.Layer,
101-
clusterable_layer.ClusterableLayer):
102-
""" This keras custom layer is derived from ClusterableLayer
103-
and it provides own implementation of the clustering
104-
algorithm.
105-
"""
101+
clusterable_layer.ClusterableLayer):
102+
"""Custom Keras clusterable layer, providing its own clustering algorithm."""
106103

107104
def __init__(self):
108105
super().__init__()
@@ -114,6 +111,7 @@ def get_clusterable_weights(self):
114111
def get_clusterable_algorithm(self, weight_name):
115112
return TestCustomerableWeightsCA
116113

114+
117115
class ClusterTest(test.TestCase, parameterized.TestCase):
118116
"""Unit tests for the cluster module."""
119117

@@ -128,6 +126,7 @@ def setUp(self):
128126
self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1))
129127
self.clusterable_layer = MyClusterableLayer(10)
130128
self.keras_custom_layer = KerasCustomLayer()
129+
self.clusterable_layer = MyClusterableLayer(10)
131130

132131

133132
clustering_registry.ClusteringLookupRegistry.register_new_implementation(
@@ -274,20 +273,34 @@ def testClusterMyClusterableLayerInvalid(self):
274273
MyClusterableLayerInvalid(10) # pylint: disable=abstract-class-instantiated
275274

276275
def testClusterKerasCustomLayer(self):
277-
"""
278-
Verifies that attempting to cluster a keras custom layer raises
279-
an exception.
280-
"""
276+
"""Verifies that attempting to cluster a keras custom layer raises an exception."""
281277
# If layer is not built, it has not weights, so
282278
# we just skip it.
283279
keras_custom_layer = self.keras_custom_layer
284-
cluster_wrapper.ClusterWeights(keras_custom_layer,
285-
**self.params)
280+
cluster_wrapper.ClusterWeights(keras_custom_layer, **self.params)
286281
# We need to build weights before check that clustering is not supported.
287282
keras_custom_layer.build(input_shape=(10, 10))
288283
with self.assertRaises(ValueError):
289-
cluster_wrapper.ClusterWeights(keras_custom_layer,
290-
**self.params)
284+
cluster_wrapper.ClusterWeights(keras_custom_layer, **self.params)
285+
286+
def testClusterMyClusterableLayer(self):
287+
# we have weights to cluster.
288+
layer = self.clusterable_layer
289+
layer.build(input_shape=(10, 10))
290+
291+
wrapped_layer = cluster_wrapper.ClusterWeights(layer, **self.params)
292+
self.assertIsInstance(wrapped_layer, cluster_wrapper.ClusterWeights)
293+
294+
def testKerasCustomLayerClusterable(self):
295+
"""Verifies that we can wrap keras custom layer that is customerable."""
296+
layer = KerasCustomLayerClusterable()
297+
wrapped_layer = cluster_wrapper.ClusterWeights(layer, **self.params)
298+
self.assertIsInstance(wrapped_layer, cluster_wrapper.ClusterWeights)
299+
300+
def testClusterMyClusterableLayerInvalid(self):
301+
"""Verifies that an exception is raised when get_clusterable_weights() is not implemented."""
302+
with self.assertRaises(TypeError):
303+
MyClusterableLayerInvalid(10) # pylint: disable=abstract-class-instantiated
291304

292305
@keras_parameterized.run_all_keras_modes
293306
def testClusterSequentialModelSelectively(self):

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py

Lines changed: 35 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"""Tests for keras ClusterWeights wrapper API."""
1616

1717
import itertools
18+
import os
19+
import tempfile
1820

1921
from absl.testing import parameterized
2022
import tensorflow as tf
@@ -230,63 +232,46 @@ def assert_all_weights_associated(weights, centroid_index):
230232
# Weights should now be all clustered with the centroid 1
231233
assert_all_weights_associated(l.layer.kernel, centroid_index=1)
232234

233-
def testClusterReassociation2(self):
234-
"""Verifies that the association of weights to cluster centroids are updated every iteration."""
235-
235+
def testSameWeightsAreReturnedBeforeAndAfterSerialisation(self):
236+
"""Verify weights of cluster_wrapper are the same after serialisation."""
236237
# Create a dummy layer for this test
237238
input_shape = (1, 2,)
238-
l = cluster_wrapper.ClusterWeights(
239+
original_layer = cluster_wrapper.ClusterWeights(
239240
keras.layers.Dense(8, input_shape=input_shape),
240241
number_of_clusters=2,
241242
cluster_centroids_init=CentroidInitialization.LINEAR
242243
)
243244
# Build a layer with the given shape
244-
l.build(input_shape)
245-
246-
# Get name of the clusterable weights
247-
clusterable_weights = l.layer.get_clusterable_weights()
248-
self.assertLen(clusterable_weights, 1)
249-
weights_name = clusterable_weights[0][0]
250-
self.assertEqual(weights_name, 'kernel')
251-
# Get cluster centroids
252-
centroids = l.cluster_centroids_tf[weights_name]
253-
254-
# Calculate some statistics of the weights to set the centroids later on
255-
mean_weight = tf.reduce_mean(l.layer.kernel)
256-
min_weight = tf.reduce_min(l.layer.kernel)
257-
max_weight = tf.reduce_max(l.layer.kernel)
258-
max_dist = max_weight - min_weight
259-
260-
def assert_all_weights_associated(weights, centroid_index):
261-
"""Helper function to make sure that all weights are associated with one centroid."""
262-
all_associated = tf.reduce_all(
263-
tf.equal(
264-
weights,
265-
tf.constant(centroids[centroid_index], shape=weights.shape)
266-
)
267-
)
268-
self.assertTrue(all_associated)
269-
270-
# Set centroids so that all weights should be re-associated with centroid 0
271-
centroids[0].assign(mean_weight)
272-
centroids[1].assign(mean_weight + 2.0 * max_dist)
273-
274-
# Update associations of weights to centroids
275-
l.call(tf.ones(shape=input_shape))
276-
277-
# Weights should now be all clustered with the centroid 0
278-
assert_all_weights_associated(l.layer.kernel, centroid_index=0)
279-
280-
# Set centroids so that all weights should be re-associated with centroid 1
281-
centroids[0].assign(mean_weight - 2.0 * max_dist)
282-
centroids[1].assign(mean_weight)
283-
284-
# Update associations of weights to centroids
285-
l.call(tf.ones(shape=input_shape))
286-
287-
# Weights should now be all clustered with the centroid 1
288-
assert_all_weights_associated(l.layer.kernel, centroid_index=1)
289-
245+
original_layer.build(input_shape)
246+
model = keras.Sequential([original_layer])
247+
248+
# Save and load the layer in a temp directory
249+
with tempfile.TemporaryDirectory() as tmp_dir_name:
250+
keras_file = os.path.join(tmp_dir_name, 'keras_model')
251+
keras.models.save_model(model, keras_file)
252+
with cluster.cluster_scope():
253+
loaded_layer = keras.models.load_model(keras_file).layers[0]
254+
255+
def assert_list_of_variables_all_equal(l1, l2):
256+
self.assertLen(
257+
l1, len(l2),
258+
'lists l1 and l2 are not equal: \n l1={l1} \n l2={l2}'.format(
259+
l1=[v.name for v in l1],
260+
l2=[v.name for v in l2]))
261+
262+
name_to_var_from_l1 = {var.name: var for var in l1}
263+
for var2 in l2:
264+
self.assertIn(var2.name, name_to_var_from_l1)
265+
arr1 = name_to_var_from_l1[var2.name].numpy()
266+
arr2 = var2.numpy()
267+
self.assertAllEqual(arr1, arr2)
268+
269+
# Check that trainable_weights and non_trainable_weights are the same
270+
# in the original layer and loaded layer
271+
assert_list_of_variables_all_equal(original_layer.trainable_weights,
272+
loaded_layer.trainable_weights)
273+
assert_list_of_variables_all_equal(original_layer.non_trainable_weights,
274+
loaded_layer.non_trainable_weights)
290275

291276
if __name__ == '__main__':
292277
test.main()

0 commit comments

Comments
 (0)