import time

import tensorflow as tf  # pylint: disable=g-bad-import-order
-from tensorflow.python.training import monitored_session  # pylint: disable=g-bad-import-order

from official.utils.logs import hooks
from official.utils.testing import mock_lib

-
tf.logging.set_verbosity(tf.logging.DEBUG)


class ExamplesPerSecondHookTest(tf.test.TestCase):
-  """Tests for the ExamplesPerSecondHook."""
+  """Tests for the ExamplesPerSecondHook.
+
+  In this test, we explicitly run the global_step tensor after train_op in
+  order to grab the correct global step value. This is to correct for
+  discrepancies in the reported global step when running on GPUs. As in the
+  after_run function of ExamplesPerSecondHook, the global step taken from
+  run_results (global_step = run_values.results) is not always correct; it is
+  the stale global_step, which may be 1 off the correct value. The exact
+  global_step value should come from run_context
+  (global_step = run_context.session.run(global_step_tensor)).
+  """

  def setUp(self):
    """Mock out logging calls to verify if correct info is being monitored."""
    self._logger = mock_lib.MockBenchmarkLogger()

    self.graph = tf.Graph()
    with self.graph.as_default():
-      self.global_step = tf.train.get_or_create_global_step()
-      self.train_op = tf.assign_add(self.global_step, 1)
+      tf.train.create_global_step()
+      self.train_op = tf.assign_add(tf.train.get_global_step(), 1)
+      self.global_step = tf.train.get_global_step()

  def test_raise_in_both_secs_and_steps(self):
    with self.assertRaises(ValueError):
@@ -59,86 +68,88 @@ def test_raise_in_none_secs_and_steps(self):
        every_n_secs=None,
        metric_logger=self._logger)

-  def _validate_log_every_n_steps(self, sess, every_n_steps, warm_steps):
+  def _validate_log_every_n_steps(self, every_n_steps, warm_steps):
    hook = hooks.ExamplesPerSecondHook(
        batch_size=256,
        every_n_steps=every_n_steps,
        warm_steps=warm_steps,
        metric_logger=self._logger)
-    hook.begin()
-    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
-    sess.run(tf.global_variables_initializer())

-    for _ in range(every_n_steps):
+    with tf.train.MonitoredSession(
+        tf.train.ChiefSessionCreator(), [hook]) as mon_sess:
+      for _ in range(every_n_steps):
+        # Explicitly run global_step after train_op to get the accurate
+        # global_step value
+        mon_sess.run(self.train_op)
+        mon_sess.run(self.global_step)
+      # Nothing should be in the list yet
+      self.assertFalse(self._logger.logged_metric)
+
      mon_sess.run(self.train_op)
-      # Nothing should be in the list yet
-      self.assertFalse(self._logger.logged_metric)
+      global_step_val = mon_sess.run(self.global_step)

-    mon_sess.run(self.train_op)
-    global_step_val = sess.run(self.global_step)
+      if global_step_val > warm_steps:
+        self._assert_metrics()
+      else:
+        # Nothing should be in the list yet
+        self.assertFalse(self._logger.logged_metric)

-    if global_step_val > warm_steps:
-      self._assert_metrics()
-    else:
-      # Nothing should be in the list yet
-      self.assertFalse(self._logger.logged_metric)
-
-    # Add additional run to verify proper reset when called multiple times.
-    prev_log_len = len(self._logger.logged_metric)
-    mon_sess.run(self.train_op)
-    global_step_val = sess.run(self.global_step)
-    if every_n_steps == 1 and global_step_val > warm_steps:
-      # Each time, we log two additional metrics. Did exactly 2 get added?
-      self.assertEqual(len(self._logger.logged_metric), prev_log_len + 2)
-    else:
-      # No change in the size of the metric list.
-      self.assertEqual(len(self._logger.logged_metric), prev_log_len)
+      # Add additional run to verify proper reset when called multiple times.
+      prev_log_len = len(self._logger.logged_metric)
+      mon_sess.run(self.train_op)
+      global_step_val = mon_sess.run(self.global_step)

-    hook.end(sess)
+      if every_n_steps == 1 and global_step_val > warm_steps:
+        # Each time, we log two additional metrics. Did exactly 2 get added?
+        self.assertEqual(len(self._logger.logged_metric), prev_log_len + 2)
+      else:
+        # No change in the size of the metric list.
+        self.assertEqual(len(self._logger.logged_metric), prev_log_len)

  def test_examples_per_sec_every_1_steps(self):
-    with self.graph.as_default(), tf.Session() as sess:
-      self._validate_log_every_n_steps(sess, 1, 0)
+    with self.graph.as_default():
+      self._validate_log_every_n_steps(1, 0)

  def test_examples_per_sec_every_5_steps(self):
-    with self.graph.as_default(), tf.Session() as sess:
-      self._validate_log_every_n_steps(sess, 5, 0)
+    with self.graph.as_default():
+      self._validate_log_every_n_steps(5, 0)

  def test_examples_per_sec_every_1_steps_with_warm_steps(self):
-    with self.graph.as_default(), tf.Session() as sess:
-      self._validate_log_every_n_steps(sess, 1, 10)
+    with self.graph.as_default():
+      self._validate_log_every_n_steps(1, 10)

  def test_examples_per_sec_every_5_steps_with_warm_steps(self):
-    with self.graph.as_default(), tf.Session() as sess:
-      self._validate_log_every_n_steps(sess, 5, 10)
+    with self.graph.as_default():
+      self._validate_log_every_n_steps(5, 10)

-  def _validate_log_every_n_secs(self, sess, every_n_secs):
+  def _validate_log_every_n_secs(self, every_n_secs):
    hook = hooks.ExamplesPerSecondHook(
        batch_size=256,
        every_n_steps=None,
        every_n_secs=every_n_secs,
        metric_logger=self._logger)
-    hook.begin()
-    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
-    sess.run(tf.global_variables_initializer())
-
-    mon_sess.run(self.train_op)
-    # Nothing should be in the list yet
-    self.assertFalse(self._logger.logged_metric)
-    time.sleep(every_n_secs)

-    mon_sess.run(self.train_op)
-    self._assert_metrics()
+    with tf.train.MonitoredSession(
+        tf.train.ChiefSessionCreator(), [hook]) as mon_sess:
+      # Explicitly run global_step after train_op to get the accurate
+      # global_step value
+      mon_sess.run(self.train_op)
+      mon_sess.run(self.global_step)
+      # Nothing should be in the list yet
+      self.assertFalse(self._logger.logged_metric)
+      time.sleep(every_n_secs)

-    hook.end(sess)
+      mon_sess.run(self.train_op)
+      mon_sess.run(self.global_step)
+      self._assert_metrics()

  def test_examples_per_sec_every_1_secs(self):
-    with self.graph.as_default(), tf.Session() as sess:
-      self._validate_log_every_n_secs(sess, 1)
+    with self.graph.as_default():
+      self._validate_log_every_n_secs(1)

  def test_examples_per_sec_every_5_secs(self):
-    with self.graph.as_default(), tf.Session() as sess:
-      self._validate_log_every_n_secs(sess, 5)
+    with self.graph.as_default():
+      self._validate_log_every_n_secs(5)

  def _assert_metrics(self):
    metrics = self._logger.logged_metric
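
For reference, below is a minimal, self-contained sketch of the session pattern the updated tests rely on: a SessionRunHook wired into tf.train.MonitoredSession via tf.train.ChiefSessionCreator, with the global_step tensor re-run explicitly after train_op. It is illustrative only and separate from the diff above; it assumes TF 1.x graph mode and uses tf.train.LoggingTensorHook as a stand-in for the repository's ExamplesPerSecondHook.

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  # Same graph setup as the test fixture: a global step plus a train_op that
  # increments it by one per run.
  tf.train.create_global_step()
  train_op = tf.assign_add(tf.train.get_global_step(), 1)
  global_step = tf.train.get_global_step()
  # Stand-in hook; the tests use hooks.ExamplesPerSecondHook instead.
  hook = tf.train.LoggingTensorHook(
      tensors={'global_step': global_step}, every_n_iter=1)

  # ChiefSessionCreator initializes variables; MonitoredSession invokes the
  # hook's before_run/after_run callbacks around every run() call.
  with tf.train.MonitoredSession(
      tf.train.ChiefSessionCreator(), [hook]) as mon_sess:
    mon_sess.run(train_op)
    # Re-run the global_step tensor explicitly to read its up-to-date value,
    # rather than trusting the possibly stale value seen inside after_run.
    step_val = mon_sess.run(global_step)
    print('global step after one train step:', step_val)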