@@ -58,19 +58,21 @@ def compute_accuracy(logits, labels):
58
58
59
59
def train (model , optimizer , dataset , step_counter , log_interval = None ):
60
60
"""Trains model on `dataset` using `optimizer`."""
61
+ from tensorflow .contrib import summary as contrib_summary # pylint: disable=g-import-not-at-top
61
62
62
63
start = time .time ()
63
64
for (batch , (images , labels )) in enumerate (dataset ):
64
- with tf . contrib . summary .record_summaries_every_n_global_steps (
65
+ with contrib_summary .record_summaries_every_n_global_steps (
65
66
10 , global_step = step_counter ):
66
67
# Record the operations used to compute the loss given the input,
67
68
# so that the gradient of the loss with respect to the variables
68
69
# can be computed.
69
70
with tf .GradientTape () as tape :
70
71
logits = model (images , training = True )
71
72
loss_value = loss (logits , labels )
72
- tf .contrib .summary .scalar ('loss' , loss_value )
73
- tf .contrib .summary .scalar ('accuracy' , compute_accuracy (logits , labels ))
73
+ contrib_summary .scalar ('loss' , loss_value )
74
+ contrib_summary .scalar ('accuracy' ,
75
+ compute_accuracy (logits , labels ))
74
76
grads = tape .gradient (loss_value , model .variables )
75
77
optimizer .apply_gradients (
76
78
zip (grads , model .variables ), global_step = step_counter )
@@ -82,6 +84,7 @@ def train(model, optimizer, dataset, step_counter, log_interval=None):
82
84
83
85
def test (model , dataset ):
84
86
"""Perform an evaluation of `model` on the examples from `dataset`."""
87
+ from tensorflow .contrib import summary as contrib_summary # pylint: disable=g-import-not-at-top
85
88
avg_loss = tf .keras .metrics .Mean ('loss' , dtype = tf .float32 )
86
89
accuracy = tf .keras .metrics .Accuracy ('accuracy' , dtype = tf .float32 )
87
90
@@ -93,9 +96,9 @@ def test(model, dataset):
93
96
tf .cast (labels , tf .int64 ))
94
97
print ('Test set: Average loss: %.4f, Accuracy: %4f%%\n ' %
95
98
(avg_loss .result (), 100 * accuracy .result ()))
96
- with tf . contrib . summary .always_record_summaries ():
97
- tf . contrib . summary .scalar ('loss' , avg_loss .result ())
98
- tf . contrib . summary .scalar ('accuracy' , accuracy .result ())
99
+ with contrib_summary .always_record_summaries ():
100
+ contrib_summary .scalar ('loss' , avg_loss .result ())
101
+ contrib_summary .scalar ('accuracy' , accuracy .result ())
99
102
100
103
101
104
def run_mnist_eager (flags_obj ):
@@ -137,9 +140,9 @@ def run_mnist_eager(flags_obj):
137
140
else :
138
141
train_dir = None
139
142
test_dir = None
140
- summary_writer = tf .contrib .summary .create_file_writer (
143
+ summary_writer = tf .compat . v2 .summary .create_file_writer (
141
144
train_dir , flush_millis = 10000 )
142
- test_summary_writer = tf .contrib .summary .create_file_writer (
145
+ test_summary_writer = tf .compat . v2 .summary .create_file_writer (
143
146
test_dir , flush_millis = 10000 , name = 'test' )
144
147
145
148
# Create and restore checkpoint (if one exists on the path)
0 commit comments