Skip to content

Commit c27c9ca

Browse files
daverimtensorflower-gardener
authored andcommitted
Add monitoring for API usage
PiperOrigin-RevId: 365508776
1 parent 3c4f3b2 commit c27c9ca

File tree

10 files changed

+191
-0
lines changed

10 files changed

+191
-0
lines changed

tensorflow_model_optimization/python/core/keras/BUILD

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,23 @@ py_library(
4343
# python:variables tensorflow dep2,
4444
],
4545
)
46+
47+
py_library(
48+
name = "metrics",
49+
srcs = ["metrics.py"],
50+
srcs_version = "PY3",
51+
deps = [
52+
# python/eager:monitoring tensorflow dep2,
53+
],
54+
)
55+
56+
py_test(
57+
name = "metrics_test",
58+
srcs = ["metrics_test.py"],
59+
python_version = "PY3",
60+
deps = [
61+
":metrics",
62+
# mock dep1,
63+
# tensorflow dep1,
64+
],
65+
)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Implements monitoring."""
16+
17+
from tensorflow.python.eager import monitoring
18+
19+
20+
class MonitorBoolGauge():
21+
"""Monitoring utility class for usage metrics."""
22+
23+
_PRUNE_LOW_MAGNITUDE_USAGE = monitoring.BoolGauge(
24+
'/tfmot/api/sparsity/prune_low_magnitude',
25+
'prune_low_magnitude usage.', 'status')
26+
27+
_PRUNE_WRAPPER_USAGE = monitoring.BoolGauge(
28+
'/tfmot/api/sparsity/pruning_wrapper',
29+
'Pruning wrapper class usage.', 'layer')
30+
31+
_QUANTIZE_APPLY_USAGE = monitoring.BoolGauge(
32+
'/tfmot/api/quantization/quantize_apply',
33+
'quantize_apply usage.', 'status')
34+
35+
_QUANTIZE_WRAPPER_USAGE = monitoring.BoolGauge(
36+
'/tfmot/api/quantization/quantize_wrapper',
37+
'Quantize wrapper class usage.', 'layer')
38+
39+
_SUCCESS_LABEL = 'success'
40+
_FAILURE_LABEL = 'failure'
41+
42+
def __init__(self, name):
43+
self.bool_gauge = self.get_usage_gauge(name)
44+
45+
def get_usage_gauge(self, name):
46+
if name == 'prune_low_magnitude_usage':
47+
return MonitorBoolGauge._PRUNE_LOW_MAGNITUDE_USAGE
48+
if name == 'prune_low_magnitude_wrapper_usage':
49+
return MonitorBoolGauge._PRUNE_WRAPPER_USAGE
50+
if name == 'quantize_apply_usage':
51+
return MonitorBoolGauge._QUANTIZE_APPLY_USAGE
52+
if name == 'quantize_wrapper_usage':
53+
return MonitorBoolGauge._QUANTIZE_WRAPPER_USAGE
54+
raise ValueError('Invalid gauge name: {}'.format(name))
55+
56+
def __call__(self, func):
57+
def inner(*args, **kwargs):
58+
try:
59+
results = func(*args, **kwargs)
60+
self.bool_gauge.get_cell(MonitorBoolGauge._SUCCESS_LABEL).set(True)
61+
return results
62+
except Exception as error:
63+
self.bool_gauge.get_cell(MonitorBoolGauge._FAILURE_LABEL).set(True)
64+
raise error
65+
66+
if self.bool_gauge:
67+
return inner
68+
69+
return func
70+
71+
def set(self, label=None, value=True):
72+
"""Set the bool gauge to value if initialized.
73+
74+
Args:
75+
label: optional string label defaults to None.
76+
value: optional bool value defaults to True.
77+
"""
78+
if self.bool_gauge:
79+
self.bool_gauge.get_cell(label).set(value)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for Metrics."""
16+
17+
import mock
18+
import tensorflow as tf
19+
20+
from tensorflow.python.eager import monitoring
21+
from tensorflow_model_optimization.python.core.keras import metrics
22+
23+
24+
class MetricsTest(tf.test.TestCase):
25+
26+
gauge = monitoring.BoolGauge('/tfmot/metrics/testing', 'testing', 'labels')
27+
28+
def setUp(self):
29+
super(MetricsTest, self).setUp()
30+
self.test_label = tf.keras.layers.Conv2D(1, 1).__class__.__name__
31+
for label in [
32+
self.test_label, metrics.MonitorBoolGauge._SUCCESS_LABEL,
33+
metrics.MonitorBoolGauge._FAILURE_LABEL
34+
]:
35+
MetricsTest.gauge.get_cell(label).set(False)
36+
37+
with mock.patch.object(metrics.MonitorBoolGauge, 'get_usage_gauge',
38+
return_value=MetricsTest.gauge):
39+
self.monitor = metrics.MonitorBoolGauge('testing')
40+
41+
def test_DecoratorTest(self):
42+
@self.monitor
43+
def func(x):
44+
return x + 1
45+
46+
self.assertEqual(func(1), 2)
47+
self.assertTrue(MetricsTest.gauge.get_cell(
48+
metrics.MonitorBoolGauge._SUCCESS_LABEL).value())
49+
50+
def test_DecoratorFailureTest(self):
51+
@self.monitor
52+
def func(x):
53+
raise ValueError()
54+
55+
with self.assertRaises(ValueError):
56+
func(1)
57+
self.assertTrue(MetricsTest.gauge.get_cell(
58+
metrics.MonitorBoolGauge._FAILURE_LABEL).value())
59+
60+
def test_UndecoratedTest(self):
61+
with self.assertRaises(ValueError):
62+
@metrics.MonitorBoolGauge('unknown')
63+
def func(x):
64+
return x+1
65+
func(1)
66+
67+
def test_SetTest(self):
68+
self.monitor.set(self.test_label)
69+
self.assertTrue(MetricsTest.gauge.get_cell(self.test_label).value())
70+
71+
72+
if __name__ == '__main__':
73+
tf.test.main()

tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ py_library(
211211
":quantize_config",
212212
":quantizers",
213213
# tensorflow dep1,
214+
"//tensorflow_model_optimization/python/core/keras:metrics",
214215
"//tensorflow_model_optimization/python/core/keras:utils",
215216
],
216217
)
@@ -244,6 +245,7 @@ py_library(
244245
":quantize_layer",
245246
":quantize_wrapper",
246247
# tensorflow dep1,
248+
"//tensorflow_model_optimization/python/core/keras:metrics",
247249
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
248250
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_scheme",
249251
"//tensorflow_model_optimization/python/core/quantization/keras/layers:conv_batchnorm",

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@
4545
# TODO(alanchiao): reduce redundancy by parameterizing on Depthwise vs Conv.
4646
class DefaultTransformsTest(tf.test.TestCase, parameterized.TestCase):
4747

48+
@classmethod
49+
def setUpClass(cls):
50+
super(DefaultTransformsTest, cls).setUpClass()
51+
np.random.seed(12345678)
52+
4853
def testTransformsConvBNReLUPattern(self):
4954
model = Conv2DModel.get_nonfolded_batchnorm_model(
5055
post_bn_activation=keras.layers.ReLU(6.0), model_type='functional')

tensorflow_model_optimization/python/core/quantization/keras/quantize.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import tensorflow as tf
1818

19+
from tensorflow_model_optimization.python.core.keras import metrics
1920
from tensorflow_model_optimization.python.core.quantization.keras import quantize_annotate as quantize_annotate_mod
2021
from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
2122
from tensorflow_model_optimization.python.core.quantization.keras import quantize_config as quantize_config_mod
@@ -263,6 +264,7 @@ def quantize_annotate_layer(to_annotate, quantize_config=None):
263264
layer=to_annotate, quantize_config=quantize_config)
264265

265266

267+
@metrics.MonitorBoolGauge('quantize_apply_usage')
266268
def quantize_apply(
267269
model,
268270
scheme=default_8bit_quantize_scheme.Default8BitQuantizeScheme()):

tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from tensorflow.python.util import tf_inspect
3232

33+
from tensorflow_model_optimization.python.core.keras import metrics
3334
from tensorflow_model_optimization.python.core.keras import utils
3435
from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
3536

@@ -70,6 +71,8 @@ def __init__(self, layer, quantize_config, **kwargs):
7071
self.quantize_config = quantize_config
7172

7273
self._track_trackable(layer, name='layer')
74+
metrics.MonitorBoolGauge('quantize_wrapper_usage').set(
75+
layer.__class__.__name__)
7376

7477
@staticmethod
7578
def _make_layer_name(layer):

tensorflow_model_optimization/python/core/sparsity/keras/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ py_library(
2727
":pruning_schedule",
2828
":pruning_wrapper",
2929
# tensorflow dep1,
30+
"//tensorflow_model_optimization/python/core/keras:metrics",
3031
],
3132
)
3233

@@ -92,6 +93,7 @@ py_library(
9293
# tensorflow dep1,
9394
# python/keras/utils:generic_utils tensorflow dep2,
9495
"//tensorflow_model_optimization/python/core/keras:compat",
96+
"//tensorflow_model_optimization/python/core/keras:metrics",
9597
"//tensorflow_model_optimization/python/core/keras:utils",
9698
],
9799
)

tensorflow_model_optimization/python/core/sparsity/keras/prune.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import tensorflow as tf
1919

20+
from tensorflow_model_optimization.python.core.keras import metrics
2021
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule as pruning_sched
2122
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
2223

@@ -52,6 +53,7 @@ def prune_scope():
5253
{'PruneLowMagnitude': pruning_wrapper.PruneLowMagnitude})
5354

5455

56+
@metrics.MonitorBoolGauge('prune_low_magnitude_usage')
5557
def prune_low_magnitude(to_prune,
5658
pruning_schedule=pruning_sched.ConstantSparsity(0.5, 0),
5759
block_size=(1, 1),

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
# b/(139939526): update to use public API.
2828
from tensorflow.python.keras.utils import generic_utils
2929
from tensorflow_model_optimization.python.core.keras import compat as tf_compat
30+
from tensorflow_model_optimization.python.core.keras import metrics
3031
from tensorflow_model_optimization.python.core.keras import utils
3132
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
3233
from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry
@@ -185,6 +186,8 @@ def __init__(self,
185186
if not hasattr(self, '_batch_input_shape') and hasattr(
186187
layer, '_batch_input_shape'):
187188
self._batch_input_shape = self.layer._batch_input_shape
189+
metrics.MonitorBoolGauge('prune_low_magnitude_wrapper_usage').set(
190+
layer.__class__.__name__)
188191

189192
def build(self, input_shape):
190193
super(PruneLowMagnitude, self).build(input_shape)

0 commit comments

Comments
 (0)