Skip to content

Commit d41ed93

Browse files
authored
Fix hooks_test for examples/second hook (#4411)
* Fix hooks_test * Add more comments * Fix lints
1 parent 04c8187 commit d41ed93

File tree

1 file changed

+66
-55
lines changed

1 file changed

+66
-55
lines changed

official/utils/logs/hooks_test.py

Lines changed: 66 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,35 @@
2222
import time
2323

2424
import tensorflow as tf # pylint: disable=g-bad-import-order
25-
from tensorflow.python.training import monitored_session # pylint: disable=g-bad-import-order
2625

2726
from official.utils.logs import hooks
2827
from official.utils.testing import mock_lib
2928

30-
3129
tf.logging.set_verbosity(tf.logging.DEBUG)
3230

3331

3432
class ExamplesPerSecondHookTest(tf.test.TestCase):
35-
"""Tests for the ExamplesPerSecondHook."""
33+
"""Tests for the ExamplesPerSecondHook.
34+
35+
In this test, we explicitly run global_step tensor after train_op in order to
36+
grab the correct global step value. This is to correct for discrepancies in
37+
reported global step when running on GPUs. As in the after_run functions in
38+
ExamplesPerSecondHook, the global step from run_results
39+
(global_step = run_values.results) is not always correct and taken as the
40+
stale global_step (which may be 1 off the correct value). The exact
41+
global_step value should be from run_context
42+
(global_step = run_context.session.run(global_step_tensor)
43+
"""
3644

3745
def setUp(self):
3846
"""Mock out logging calls to verify if correct info is being monitored."""
3947
self._logger = mock_lib.MockBenchmarkLogger()
4048

4149
self.graph = tf.Graph()
4250
with self.graph.as_default():
43-
self.global_step = tf.train.get_or_create_global_step()
44-
self.train_op = tf.assign_add(self.global_step, 1)
51+
tf.train.create_global_step()
52+
self.train_op = tf.assign_add(tf.train.get_global_step(), 1)
53+
self.global_step = tf.train.get_global_step()
4554

4655
def test_raise_in_both_secs_and_steps(self):
4756
with self.assertRaises(ValueError):
@@ -59,86 +68,88 @@ def test_raise_in_none_secs_and_steps(self):
5968
every_n_secs=None,
6069
metric_logger=self._logger)
6170

62-
def _validate_log_every_n_steps(self, sess, every_n_steps, warm_steps):
71+
def _validate_log_every_n_steps(self, every_n_steps, warm_steps):
6372
hook = hooks.ExamplesPerSecondHook(
6473
batch_size=256,
6574
every_n_steps=every_n_steps,
6675
warm_steps=warm_steps,
6776
metric_logger=self._logger)
68-
hook.begin()
69-
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
70-
sess.run(tf.global_variables_initializer())
7177

72-
for _ in range(every_n_steps):
78+
with tf.train.MonitoredSession(
79+
tf.train.ChiefSessionCreator(), [hook]) as mon_sess:
80+
for _ in range(every_n_steps):
81+
# Explicitly run global_step after train_op to get the accurate
82+
# global_step value
83+
mon_sess.run(self.train_op)
84+
mon_sess.run(self.global_step)
85+
# Nothing should be in the list yet
86+
self.assertFalse(self._logger.logged_metric)
87+
7388
mon_sess.run(self.train_op)
74-
# Nothing should be in the list yet
75-
self.assertFalse(self._logger.logged_metric)
89+
global_step_val = mon_sess.run(self.global_step)
7690

77-
mon_sess.run(self.train_op)
78-
global_step_val = sess.run(self.global_step)
91+
if global_step_val > warm_steps:
92+
self._assert_metrics()
93+
else:
94+
# Nothing should be in the list yet
95+
self.assertFalse(self._logger.logged_metric)
7996

80-
if global_step_val > warm_steps:
81-
self._assert_metrics()
82-
else:
83-
# Nothing should be in the list yet
84-
self.assertFalse(self._logger.logged_metric)
85-
86-
# Add additional run to verify proper reset when called multiple times.
87-
prev_log_len = len(self._logger.logged_metric)
88-
mon_sess.run(self.train_op)
89-
global_step_val = sess.run(self.global_step)
90-
if every_n_steps == 1 and global_step_val > warm_steps:
91-
# Each time, we log two additional metrics. Did exactly 2 get added?
92-
self.assertEqual(len(self._logger.logged_metric), prev_log_len + 2)
93-
else:
94-
# No change in the size of the metric list.
95-
self.assertEqual(len(self._logger.logged_metric), prev_log_len)
97+
# Add additional run to verify proper reset when called multiple times.
98+
prev_log_len = len(self._logger.logged_metric)
99+
mon_sess.run(self.train_op)
100+
global_step_val = mon_sess.run(self.global_step)
96101

97-
hook.end(sess)
102+
if every_n_steps == 1 and global_step_val > warm_steps:
103+
# Each time, we log two additional metrics. Did exactly 2 get added?
104+
self.assertEqual(len(self._logger.logged_metric), prev_log_len + 2)
105+
else:
106+
# No change in the size of the metric list.
107+
self.assertEqual(len(self._logger.logged_metric), prev_log_len)
98108

99109
def test_examples_per_sec_every_1_steps(self):
100-
with self.graph.as_default(), tf.Session() as sess:
101-
self._validate_log_every_n_steps(sess, 1, 0)
110+
with self.graph.as_default():
111+
self._validate_log_every_n_steps(1, 0)
102112

103113
def test_examples_per_sec_every_5_steps(self):
104-
with self.graph.as_default(), tf.Session() as sess:
105-
self._validate_log_every_n_steps(sess, 5, 0)
114+
with self.graph.as_default():
115+
self._validate_log_every_n_steps(5, 0)
106116

107117
def test_examples_per_sec_every_1_steps_with_warm_steps(self):
108-
with self.graph.as_default(), tf.Session() as sess:
109-
self._validate_log_every_n_steps(sess, 1, 10)
118+
with self.graph.as_default():
119+
self._validate_log_every_n_steps(1, 10)
110120

111121
def test_examples_per_sec_every_5_steps_with_warm_steps(self):
112-
with self.graph.as_default(), tf.Session() as sess:
113-
self._validate_log_every_n_steps(sess, 5, 10)
122+
with self.graph.as_default():
123+
self._validate_log_every_n_steps(5, 10)
114124

115-
def _validate_log_every_n_secs(self, sess, every_n_secs):
125+
def _validate_log_every_n_secs(self, every_n_secs):
116126
hook = hooks.ExamplesPerSecondHook(
117127
batch_size=256,
118128
every_n_steps=None,
119129
every_n_secs=every_n_secs,
120130
metric_logger=self._logger)
121-
hook.begin()
122-
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
123-
sess.run(tf.global_variables_initializer())
124-
125-
mon_sess.run(self.train_op)
126-
# Nothing should be in the list yet
127-
self.assertFalse(self._logger.logged_metric)
128-
time.sleep(every_n_secs)
129131

130-
mon_sess.run(self.train_op)
131-
self._assert_metrics()
132+
with tf.train.MonitoredSession(
133+
tf.train.ChiefSessionCreator(), [hook]) as mon_sess:
134+
# Explicitly run global_step after train_op to get the accurate
135+
# global_step value
136+
mon_sess.run(self.train_op)
137+
mon_sess.run(self.global_step)
138+
# Nothing should be in the list yet
139+
self.assertFalse(self._logger.logged_metric)
140+
time.sleep(every_n_secs)
132141

133-
hook.end(sess)
142+
mon_sess.run(self.train_op)
143+
mon_sess.run(self.global_step)
144+
self._assert_metrics()
134145

135146
def test_examples_per_sec_every_1_secs(self):
136-
with self.graph.as_default(), tf.Session() as sess:
137-
self._validate_log_every_n_secs(sess, 1)
147+
with self.graph.as_default():
148+
self._validate_log_every_n_secs(1)
138149

139150
def test_examples_per_sec_every_5_secs(self):
140-
with self.graph.as_default(), tf.Session() as sess:
141-
self._validate_log_every_n_secs(sess, 5)
151+
with self.graph.as_default():
152+
self._validate_log_every_n_secs(5)
142153

143154
def _assert_metrics(self):
144155
metrics = self._logger.logged_metric

0 commit comments

Comments
 (0)