
Commit a780e56
Merge branch 'master' into clusterable_layer
2 parents: cd9b2a3 + 9193d70

File tree: 14 files changed, +378 −15 lines

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 37 additions & 0 deletions

@@ -54,6 +54,25 @@ def get_clusterable_weights(self):
 class CustomNonClusterableLayer(layers.Dense):
   pass
 
+class KerasCustomLayer(keras.layers.Layer):
+  def __init__(self, units=32):
+    super(KerasCustomLayer, self).__init__()
+    self.units = units
+
+  def build(self, input_shape):
+    self.w = self.add_weight(
+        shape=(input_shape[-1], self.units),
+        initializer="random_normal",
+        trainable=True,
+    )
+    self.b = self.add_weight(
+        shape=(self.units,),
+        initializer="random_normal",
+        trainable=False
+    )
+
+  def call(self, inputs):
+    return tf.matmul(inputs, self.w) + self.b
 
 class MyClusterableLayer(keras.layers.Dense,
                          clusterable_layer.ClusterableLayer):

@@ -108,6 +127,8 @@ def setUp(self):
     self.custom_non_clusterable_layer = CustomNonClusterableLayer(10)
     self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1))
     self.clusterable_layer = MyClusterableLayer(10)
+    self.keras_custom_layer = KerasCustomLayer()
+
 
     clustering_registry.ClusteringLookupRegistry.register_new_implementation(
         {

@@ -252,6 +273,22 @@ def testClusterMyClusterableLayerInvalid(self):
     with self.assertRaises(TypeError):
       MyClusterableLayerInvalid(10)  # pylint: disable=abstract-class-instantiated
 
+  def testClusterKerasCustomLayer(self):
+    """
+    Verifies that attempting to cluster a Keras custom layer raises
+    an exception.
+    """
+    # If the layer is not built, it has no weights, so
+    # we just skip it.
+    keras_custom_layer = self.keras_custom_layer
+    cluster_wrapper.ClusterWeights(keras_custom_layer,
+                                   **self.params)
+    # We need to build the weights before checking that clustering is not supported.
+    keras_custom_layer.build(input_shape=(10, 10))
+    with self.assertRaises(ValueError):
+      cluster_wrapper.ClusterWeights(keras_custom_layer,
+                                     **self.params)
+
   @keras_parameterized.run_all_keras_modes
   def testClusterSequentialModelSelectively(self):
     clustered_model = keras.Sequential()
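For context (not part of this commit): the ValueError exercised above is avoided when a custom layer opts in to clustering by also implementing clusterable_layer.ClusterableLayer, as MyClusterableLayer does elsewhere in this test file. A minimal sketch, assuming the (name, weight) tuple convention of get_clusterable_weights used in this file; the class name ClusterableCustomLayer is illustrative only:

import tensorflow as tf
from tensorflow import keras
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer

class ClusterableCustomLayer(keras.layers.Layer,
                             clusterable_layer.ClusterableLayer):
  """Same layer as KerasCustomLayer above, but made clusterable."""

  def __init__(self, units=32):
    super(ClusterableCustomLayer, self).__init__()
    self.units = units

  def build(self, input_shape):
    self.w = self.add_weight(
        shape=(input_shape[-1], self.units),
        initializer="random_normal",
        trainable=True)
    self.b = self.add_weight(
        shape=(self.units,),
        initializer="random_normal",
        trainable=False)

  def call(self, inputs):
    return tf.matmul(inputs, self.w) + self.b

  def get_clusterable_weights(self):
    # Cluster only the trainable kernel; the non-trainable bias is skipped.
    return [('w', self.w)]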

tensorflow_model_optimization/python/core/keras/BUILD

Lines changed: 20 additions & 0 deletions

@@ -43,3 +43,23 @@ py_library(
         # python:variables tensorflow dep2,
     ],
 )
+
+py_library(
+    name = "metrics",
+    srcs = ["metrics.py"],
+    srcs_version = "PY3",
+    deps = [
+        # python/eager:monitoring tensorflow dep2,
+    ],
+)
+
+py_test(
+    name = "metrics_test",
+    srcs = ["metrics_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":metrics",
+        # mock dep1,
+        # tensorflow dep1,
+    ],
+)
tensorflow_model_optimization/python/core/keras/metrics.py (new file)

Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implements monitoring."""
+
+from tensorflow.python.eager import monitoring
+
+
+class MonitorBoolGauge():
+  """Monitoring utility class for usage metrics."""
+
+  _PRUNE_LOW_MAGNITUDE_USAGE = monitoring.BoolGauge(
+      '/tfmot/api/sparsity/prune_low_magnitude',
+      'prune_low_magnitude usage.', 'status')
+
+  _PRUNE_WRAPPER_USAGE = monitoring.BoolGauge(
+      '/tfmot/api/sparsity/pruning_wrapper',
+      'Pruning wrapper class usage.', 'layer')
+
+  _QUANTIZE_APPLY_USAGE = monitoring.BoolGauge(
+      '/tfmot/api/quantization/quantize_apply',
+      'quantize_apply usage.', 'status')
+
+  _QUANTIZE_WRAPPER_USAGE = monitoring.BoolGauge(
+      '/tfmot/api/quantization/quantize_wrapper',
+      'Quantize wrapper class usage.', 'layer')
+
+  _SUCCESS_LABEL = 'success'
+  _FAILURE_LABEL = 'failure'
+
+  def __init__(self, name):
+    self.bool_gauge = self.get_usage_gauge(name)
+
+  def get_usage_gauge(self, name):
+    if name == 'prune_low_magnitude_usage':
+      return MonitorBoolGauge._PRUNE_LOW_MAGNITUDE_USAGE
+    if name == 'prune_low_magnitude_wrapper_usage':
+      return MonitorBoolGauge._PRUNE_WRAPPER_USAGE
+    if name == 'quantize_apply_usage':
+      return MonitorBoolGauge._QUANTIZE_APPLY_USAGE
+    if name == 'quantize_wrapper_usage':
+      return MonitorBoolGauge._QUANTIZE_WRAPPER_USAGE
+    raise ValueError('Invalid gauge name: {}'.format(name))
+
+  def __call__(self, func):
+    def inner(*args, **kwargs):
+      try:
+        results = func(*args, **kwargs)
+        self.bool_gauge.get_cell(MonitorBoolGauge._SUCCESS_LABEL).set(True)
+        return results
+      except Exception as error:
+        self.bool_gauge.get_cell(MonitorBoolGauge._FAILURE_LABEL).set(True)
+        raise error
+
+    if self.bool_gauge:
+      return inner
+
+    return func
+
+  def set(self, label=None, value=True):
+    """Set the bool gauge to value if initialized.
+
+    Args:
+      label: optional string label, defaults to None.
+      value: optional bool value, defaults to True.
+    """
+    if self.bool_gauge:
+      self.bool_gauge.get_cell(label).set(value)
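A minimal usage sketch of the new module (not part of the commit): MonitorBoolGauge is constructed with one of the names handled by get_usage_gauge and applied as a decorator, so the 'success' cell is set when the wrapped call returns and the 'failure' cell when it raises. The decorated function below is purely hypothetical:

from tensorflow_model_optimization.python.core.keras import metrics

@metrics.MonitorBoolGauge('quantize_apply_usage')
def apply_quantization(model):
  # Hypothetical entry point, for illustration only. On a normal return the
  # 'success' cell of the gauge is set; if this body raises, the 'failure'
  # cell is set and the exception is re-raised.
  return model
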
tensorflow_model_optimization/python/core/keras/metrics_test.py (new file)

Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Metrics."""
+
+import mock
+import tensorflow as tf
+
+from tensorflow.python.eager import monitoring
+from tensorflow_model_optimization.python.core.keras import metrics
+
+
+class MetricsTest(tf.test.TestCase):
+
+  gauge = monitoring.BoolGauge('/tfmot/metrics/testing', 'testing', 'labels')
+
+  def setUp(self):
+    super(MetricsTest, self).setUp()
+    self.test_label = tf.keras.layers.Conv2D(1, 1).__class__.__name__
+    for label in [
+        self.test_label, metrics.MonitorBoolGauge._SUCCESS_LABEL,
+        metrics.MonitorBoolGauge._FAILURE_LABEL
+    ]:
+      MetricsTest.gauge.get_cell(label).set(False)
+
+    with mock.patch.object(metrics.MonitorBoolGauge, 'get_usage_gauge',
+                           return_value=MetricsTest.gauge):
+      self.monitor = metrics.MonitorBoolGauge('testing')
+
+  def test_DecoratorTest(self):
+    @self.monitor
+    def func(x):
+      return x + 1
+
+    self.assertEqual(func(1), 2)
+    self.assertTrue(MetricsTest.gauge.get_cell(
+        metrics.MonitorBoolGauge._SUCCESS_LABEL).value())
+
+  def test_DecoratorFailureTest(self):
+    @self.monitor
+    def func(x):
+      raise ValueError()
+
+    with self.assertRaises(ValueError):
+      func(1)
+    self.assertTrue(MetricsTest.gauge.get_cell(
+        metrics.MonitorBoolGauge._FAILURE_LABEL).value())
+
+  def test_UndecoratedTest(self):
+    with self.assertRaises(ValueError):
+      @metrics.MonitorBoolGauge('unknown')
+      def func(x):
+        return x + 1
+      func(1)
+
+  def test_SetTest(self):
+    self.monitor.set(self.test_label)
+    self.assertTrue(MetricsTest.gauge.get_cell(self.test_label).value())
+
+
+if __name__ == '__main__':
+  tf.test.main()

tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 2 additions & 0 deletions

@@ -211,6 +211,7 @@ py_library(
         ":quantize_config",
         ":quantizers",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:metrics",
         "//tensorflow_model_optimization/python/core/keras:utils",
     ],
 )

@@ -244,6 +245,7 @@ py_library(
         ":quantize_layer",
         ":quantize_wrapper",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:metrics",
         "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
         "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_scheme",
         "//tensorflow_model_optimization/python/core/quantization/keras/layers:conv_batchnorm",

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py

Lines changed: 2 additions & 2 deletions

@@ -65,8 +65,8 @@ def apply(self, model, layer_quantize_map):
         default_8bit_transforms.ConcatTransform4Inputs(),
         default_8bit_transforms.ConcatTransform3Inputs(),
         default_8bit_transforms.ConcatTransform(),
-        default_8bit_transforms.AddReLUQuantize(),
-        default_8bit_transforms.AddActivationQuantize(),
+        default_8bit_transforms.LayerReLUQuantize(),
+        default_8bit_transforms.LayerReluActivationQuantize(),
     ]
     return model_transformer.ModelTransformer(
         model, transforms,

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

Lines changed: 5 additions & 4 deletions

@@ -497,11 +497,12 @@ def replacement(self, match_layer):
         metadata=conv_metadata)
 
 
-class AddReLUQuantize(transforms.Transform):
+class LayerReLUQuantize(transforms.Transform):
   """Ensure FQ does not get placed between Add and ReLU."""
 
   def pattern(self):
-    return LayerPattern('ReLU', inputs=[LayerPattern('Add')])
+    return LayerPattern(
+        'ReLU', inputs=[LayerPattern('Add|Conv2D|DepthwiseConv2D|Dense')])
 
   def replacement(self, match_layer):
     relu_layer_node = match_layer

@@ -518,14 +519,14 @@ def custom_objects(self):
     }
 
 
-class AddActivationQuantize(AddReLUQuantize):
+class LayerReluActivationQuantize(LayerReLUQuantize):
   """Ensure FQ does not get placed between Add and ReLU."""
 
   def pattern(self):
     return LayerPattern(
         'Activation',
         config={'activation': 'relu'},
-        inputs=[LayerPattern('Add')])
+        inputs=[LayerPattern('Add|Conv2D|DepthwiseConv2D|Dense')])
 
 
 class InputLayerQuantize(transforms.Transform):
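Informally, the widened pattern means these transforms now also fire when a ReLU (or Activation('relu')) directly follows a Conv2D, DepthwiseConv2D, or Dense layer, not only an Add. A small illustrative model of the kind the renamed LayerReLUQuantize matches, mirroring the new test below; shapes and parameters are arbitrary:

import tensorflow as tf
from tensorflow import keras

# Conv2D feeding directly into ReLU: under the old 'Add'-only pattern this
# pair was not matched; with 'Add|Conv2D|DepthwiseConv2D|Dense' it is, so the
# conv layer receives a NoOpQuantizeConfig and no fake-quant op is placed
# between the conv output and the ReLU.
model = keras.Sequential([
    keras.layers.Conv2D(5, 2, input_shape=(3, 3, 3)),
    keras.layers.ReLU(6.0),
])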

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py

Lines changed: 34 additions & 2 deletions

@@ -45,6 +45,11 @@
 # TODO(alanchiao): reduce redundancy by parameterizing on Depthwise vs Conv.
 class DefaultTransformsTest(tf.test.TestCase, parameterized.TestCase):
 
+  @classmethod
+  def setUpClass(cls):
+    super(DefaultTransformsTest, cls).setUpClass()
+    np.random.seed(12345678)
+
   def testTransformsConvBNReLUPattern(self):
     model = Conv2DModel.get_nonfolded_batchnorm_model(
         post_bn_activation=keras.layers.ReLU(6.0), model_type='functional')

@@ -344,8 +349,8 @@ def testSeparableConvQuantize_(self, kwargs):
   # Conv2DReshapeBatchNormActivationQuantize
 
   @parameterized.parameters(
-      ('relu', default_8bit_transforms.AddReLUQuantize),
-      ('act_relu', default_8bit_transforms.AddActivationQuantize),
+      ('relu', default_8bit_transforms.LayerReLUQuantize),
+      ('act_relu', default_8bit_transforms.LayerReluActivationQuantize),
   )
   def testAddReLUQuantize(self, activation_type, transform_type):
     add = keras.layers.Add()

@@ -370,6 +375,33 @@ def testAddReLUQuantize(self, activation_type, transform_type):
         updated_metadata.get(add_layer.name).get('quantize_config'),
         default_8bit_quantize_configs.NoOpQuantizeConfig)
 
+  @parameterized.parameters(
+      ('relu', default_8bit_transforms.LayerReLUQuantize),
+      ('act_relu', default_8bit_transforms.LayerReluActivationQuantize))
+  def testLayerReLUQuantize(self, activation_type, transform_type):
+    # TODO(tfmot): Add tests for DepthConv and Dense
+    input_shape = (1, 3, 3, 3)
+    conv_layer = tf.keras.layers.Conv2D(5, 2, input_shape=input_shape)
+    if activation_type == 'relu':
+      act_layer = keras.layers.ReLU(6.0)
+    elif activation_type == 'act_relu':
+      act_layer = keras.layers.Activation('relu')
+
+    model = tf.keras.Sequential([conv_layer, act_layer])
+
+    transformed_model, updated_metadata = ModelTransformer(
+        model,
+        [transform_type()],
+    ).transform()
+
+    self.assertIsInstance(
+        updated_metadata.get(model.layers[0].name).get('quantize_config'),
+        default_8bit_quantize_configs.NoOpQuantizeConfig)
+
+    inputs = np.random.standard_normal(input_shape)
+    self.assertAllClose(
+        transformed_model.predict(inputs), model.predict(inputs))
+
   def testAddsQuantizeLayerAfterInputLayer(self):
     inp1 = keras.layers.Input((3,))
     inp2 = keras.layers.Input((3,))
