Skip to content

Commit 4b85dab

Browse files
authored
Add session hook for benchmark metric logging. (#3672)
* Add session hook for benchmark metric logging. Current hook is very similar as the LoggingTensorHook. Some of the function are directly copied since the original one was not exposed for import. We should seek to eventually move this code to core when it is mature enough. * Update metric_hook to use LoggingTensorHook as base. The existing hook is similar enough to LoggingTensorHook, and we should eliminate duplicate as much as possible. * Address review comment. 1. Update global step tensor handle. 2. Update tests. 3. Update document. * Update tests for py3. * Fix lint error
1 parent 83d827d commit 4b85dab

File tree

2 files changed

+338
-0
lines changed

2 files changed

+338
-0
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Session hook for logging benchmark metric."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import tensorflow as tf
22+
23+
from official.utils.logging import logger
24+
25+
26+
class LoggingMetricHook(tf.train.LoggingTensorHook):
27+
"""Hook to log benchmark metric information.
28+
29+
This hook is very similar as tf.train.LoggingTensorHook, which logs given
30+
tensors every N local steps, every N seconds, or at the end. The metric
31+
information will be logged to given log_dir or via metric_logger in JSON
32+
format, which can be consumed by data analysis pipeline later.
33+
34+
Note that if `at_end` is True, `tensors` should not include any tensor
35+
whose evaluation produces a side effect such as consuming additional inputs.
36+
"""
37+
38+
def __init__(self, tensors, log_dir=None, metric_logger=None,
39+
every_n_iter=None, every_n_secs=None, at_end=False):
40+
"""Initializer for LoggingMetricHook.
41+
42+
Args:
43+
tensors: `dict` that maps string-valued tags to tensors/tensor names,
44+
or `iterable` of tensors/tensor names.
45+
log_dir: `string`, directory path that metric hook should write log to.
46+
metric_logger: instance of `BenchmarkLogger`, the benchmark logger that
47+
hook should use to write the log. Exactly one of the `log_dir` and
48+
`metric_logger` should be provided.
49+
every_n_iter: `int`, print the values of `tensors` once every N local
50+
steps taken on the current worker.
51+
every_n_secs: `int` or `float`, print the values of `tensors` once every N
52+
seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
53+
provided.
54+
at_end: `bool` specifying whether to print the values of `tensors` at the
55+
end of the run.
56+
57+
Raises:
58+
ValueError:
59+
1. `every_n_iter` is non-positive, or
60+
2. Exactly one of every_n_iter and every_n_secs should be provided.
61+
3. Exactly one of log_dir and metric_logger should be provided.
62+
"""
63+
super(LoggingMetricHook, self).__init__(
64+
tensors=tensors,
65+
every_n_iter=every_n_iter,
66+
every_n_secs=every_n_secs,
67+
at_end=at_end)
68+
69+
if (log_dir is None) == (metric_logger is None):
70+
raise ValueError(
71+
"exactly one of log_dir and metric_logger should be provided.")
72+
73+
if log_dir is not None:
74+
self._logger = logger.BenchmarkLogger(log_dir)
75+
else:
76+
self._logger = metric_logger
77+
78+
def begin(self):
79+
super(LoggingMetricHook, self).begin()
80+
self._global_step_tensor = tf.train.get_global_step()
81+
if self._global_step_tensor is None:
82+
raise RuntimeError(
83+
"Global step should be created to use LoggingMetricHook.")
84+
if self._global_step_tensor.name not in self._current_tensors:
85+
self._current_tensors[self._global_step_tensor.name] = (
86+
self._global_step_tensor)
87+
88+
def after_run(self, unused_run_context, run_values):
89+
# should_trigger is a internal state that populated at before_run, and it is
90+
# using self_timer to determine whether it should trigger.
91+
if self._should_trigger:
92+
self._log_metric(run_values.results)
93+
94+
self._iter_count += 1
95+
96+
def end(self, session):
97+
if self._log_at_end:
98+
values = session.run(self._current_tensors)
99+
self._log_metric(values)
100+
101+
def _log_metric(self, tensor_values):
102+
self._timer.update_last_triggered_step(self._iter_count)
103+
global_step = tensor_values[self._global_step_tensor.name]
104+
# self._tag_order is populated during the init of LoggingTensorHook
105+
for tag in self._tag_order:
106+
self._logger.log_metric(tag, tensor_values[tag], global_step=global_step)
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for metric_hook."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import tempfile
22+
import time
23+
24+
import tensorflow as tf
25+
from tensorflow.python.training import monitored_session
26+
27+
from official.utils.logging import metric_hook
28+
29+
30+
class LoggingMetricHookTest(tf.test.TestCase):
31+
32+
def setUp(self):
33+
super(LoggingMetricHookTest, self).setUp()
34+
35+
class MockMetricLogger(object):
36+
def __init__(self):
37+
self.logged_metric = []
38+
39+
def log_metric(self, name, value, unit=None, global_step=None,
40+
extras=None):
41+
self.logged_metric.append({
42+
"name": name,
43+
"value": float(value),
44+
"unit": unit,
45+
"global_step": global_step,
46+
"extras": extras})
47+
48+
self._log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
49+
self._logger = MockMetricLogger()
50+
51+
def tearDown(self):
52+
super(LoggingMetricHookTest, self).tearDown()
53+
tf.gfile.DeleteRecursively(self.get_temp_dir())
54+
55+
def test_illegal_args(self):
56+
with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'):
57+
metric_hook.LoggingMetricHook(tensors=['t'], every_n_iter=0)
58+
with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'):
59+
metric_hook.LoggingMetricHook(tensors=['t'], every_n_iter=-10)
60+
with self.assertRaisesRegexp(ValueError, 'xactly one of'):
61+
metric_hook.LoggingMetricHook(
62+
tensors=['t'], every_n_iter=5, every_n_secs=5)
63+
with self.assertRaisesRegexp(ValueError, 'xactly one of'):
64+
metric_hook.LoggingMetricHook(tensors=['t'])
65+
with self.assertRaisesRegexp(ValueError, 'log_dir and metric_logger'):
66+
metric_hook.LoggingMetricHook(tensors=['t'], every_n_iter=5)
67+
with self.assertRaisesRegexp(ValueError, 'log_dir and metric_logger'):
68+
metric_hook.LoggingMetricHook(
69+
tensors=['t'], every_n_iter=5, log_dir=self._log_dir,
70+
metric_logger=self._logger)
71+
72+
def test_print_at_end_only(self):
73+
with tf.Graph().as_default(), tf.Session() as sess:
74+
tf.train.get_or_create_global_step()
75+
t = tf.constant(42.0, name='foo')
76+
train_op = tf.constant(3)
77+
hook = metric_hook.LoggingMetricHook(
78+
tensors=[t.name], at_end=True, metric_logger=self._logger)
79+
hook.begin()
80+
mon_sess = monitored_session._HookedSession(sess, [hook])
81+
sess.run(tf.global_variables_initializer())
82+
83+
for _ in range(3):
84+
mon_sess.run(train_op)
85+
self.assertEqual(self._logger.logged_metric, [])
86+
87+
hook.end(sess)
88+
self.assertEqual(len(self._logger.logged_metric), 1)
89+
metric = self._logger.logged_metric[0]
90+
self.assertRegexpMatches(metric["name"], "foo")
91+
self.assertEqual(metric["value"], 42.0)
92+
self.assertEqual(metric["unit"], None)
93+
self.assertEqual(metric["global_step"], 0)
94+
95+
def test_global_step_not_found(self):
96+
with tf.Graph().as_default(), tf.Session() as sess:
97+
t = tf.constant(42.0, name='foo')
98+
hook = metric_hook.LoggingMetricHook(
99+
tensors=[t.name], at_end=True, metric_logger=self._logger)
100+
101+
with self.assertRaisesRegexp(
102+
RuntimeError, 'should be created to use LoggingMetricHook.'):
103+
hook.begin()
104+
105+
def test_log_tensors(self):
106+
with tf.Graph().as_default(), tf.Session() as sess:
107+
tf.train.get_or_create_global_step()
108+
t1 = tf.constant(42.0, name='foo')
109+
t2 = tf.constant(43.0, name='bar')
110+
train_op = tf.constant(3)
111+
hook = metric_hook.LoggingMetricHook(
112+
tensors=[t1, t2], at_end=True, metric_logger=self._logger)
113+
hook.begin()
114+
mon_sess = monitored_session._HookedSession(sess, [hook])
115+
sess.run(tf.global_variables_initializer())
116+
117+
for _ in range(3):
118+
mon_sess.run(train_op)
119+
self.assertEqual(self._logger.logged_metric, [])
120+
121+
hook.end(sess)
122+
self.assertEqual(len(self._logger.logged_metric), 2)
123+
metric1 = self._logger.logged_metric[0]
124+
self.assertRegexpMatches(str(metric1["name"]), "foo")
125+
self.assertEqual(metric1["value"], 42.0)
126+
self.assertEqual(metric1["unit"], None)
127+
self.assertEqual(metric1["global_step"], 0)
128+
129+
metric2 = self._logger.logged_metric[1]
130+
self.assertRegexpMatches(str(metric2["name"]), "bar")
131+
self.assertEqual(metric2["value"], 43.0)
132+
self.assertEqual(metric2["unit"], None)
133+
self.assertEqual(metric2["global_step"], 0)
134+
135+
def _validate_print_every_n_steps(self, sess, at_end):
136+
t = tf.constant(42.0, name='foo')
137+
138+
train_op = tf.constant(3)
139+
hook = metric_hook.LoggingMetricHook(
140+
tensors=[t.name], every_n_iter=10, at_end=at_end,
141+
metric_logger=self._logger)
142+
hook.begin()
143+
mon_sess = monitored_session._HookedSession(sess, [hook])
144+
sess.run(tf.global_variables_initializer())
145+
mon_sess.run(train_op)
146+
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
147+
for _ in range(3):
148+
self._logger.logged_metric = []
149+
for _ in range(9):
150+
mon_sess.run(train_op)
151+
# assertNotRegexpMatches is not supported by python 3.1 and later
152+
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
153+
mon_sess.run(train_op)
154+
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
155+
156+
# Add additional run to verify proper reset when called multiple times.
157+
self._logger.logged_metric = []
158+
mon_sess.run(train_op)
159+
# assertNotRegexpMatches is not supported by python 3.1 and later
160+
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
161+
162+
self._logger.logged_metric = []
163+
hook.end(sess)
164+
if at_end:
165+
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
166+
else:
167+
# assertNotRegexpMatches is not supported by python 3.1 and later
168+
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
169+
170+
def test_print_every_n_steps(self):
171+
with tf.Graph().as_default(), tf.Session() as sess:
172+
tf.train.get_or_create_global_step()
173+
self._validate_print_every_n_steps(sess, at_end=False)
174+
# Verify proper reset.
175+
self._validate_print_every_n_steps(sess, at_end=False)
176+
177+
def test_print_every_n_steps_and_end(self):
178+
with tf.Graph().as_default(), tf.Session() as sess:
179+
tf.train.get_or_create_global_step()
180+
self._validate_print_every_n_steps(sess, at_end=True)
181+
# Verify proper reset.
182+
self._validate_print_every_n_steps(sess, at_end=True)
183+
184+
def _validate_print_every_n_secs(self, sess, at_end):
185+
t = tf.constant(42.0, name='foo')
186+
train_op = tf.constant(3)
187+
188+
hook = metric_hook.LoggingMetricHook(
189+
tensors=[t.name], every_n_secs=1.0, at_end=at_end,
190+
metric_logger=self._logger)
191+
hook.begin()
192+
mon_sess = monitored_session._HookedSession(sess, [hook])
193+
sess.run(tf.global_variables_initializer())
194+
195+
mon_sess.run(train_op)
196+
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
197+
198+
# assertNotRegexpMatches is not supported by python 3.1 and later
199+
self._logger.logged_metric = []
200+
mon_sess.run(train_op)
201+
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
202+
time.sleep(1.0)
203+
204+
self._logger.logged_metric = []
205+
mon_sess.run(train_op)
206+
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
207+
208+
self._logger.logged_metric = []
209+
hook.end(sess)
210+
if at_end:
211+
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
212+
else:
213+
# assertNotRegexpMatches is not supported by python 3.1 and later
214+
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
215+
216+
def test_print_every_n_secs(self):
217+
with tf.Graph().as_default(), tf.Session() as sess:
218+
tf.train.get_or_create_global_step()
219+
self._validate_print_every_n_secs(sess, at_end=False)
220+
# Verify proper reset.
221+
self._validate_print_every_n_secs(sess, at_end=False)
222+
223+
def test_print_every_n_secs_and_end(self):
224+
with tf.Graph().as_default(), tf.Session() as sess:
225+
tf.train.get_or_create_global_step()
226+
self._validate_print_every_n_secs(sess, at_end=True)
227+
# Verify proper reset.
228+
self._validate_print_every_n_secs(sess, at_end=True)
229+
230+
231+
if __name__ == '__main__':
232+
tf.test.main()

0 commit comments

Comments
 (0)