Skip to content
This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Commit d0ac8de

Browse files
committed
Updating resnet to save logdir, graph and checkpoints and run batches in prediction.
1 parent 0c57c03 commit d0ac8de

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

examples/resnet.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Reference Paper: http://arxiv.org/pdf/1512.03385.pdf
1818
"""
1919

20+
import os
2021
import random
2122
from sklearn import metrics
2223

@@ -132,13 +133,24 @@ def res_net(x, y, activation=tf.nn.relu):
132133
# Download and load MNIST data.
133134
mnist = input_data.read_data_sets('MNIST_data')
134135

135-
# Train a resnet classifier
136-
classifier = skflow.TensorFlowEstimator(
137-
model_fn=res_net, n_classes=10, batch_size=100, steps=20000,
138-
learning_rate=0.001)
136+
# Restore model if graph is saved into a folder.
137+
if os.path.exists("models/resnet/graph.pbtxt"):
138+
classifier = skflow.TensorFlowEstimator.restore("models/resnet/")
139+
else:
140+
# Create a new resnet classifier.
141+
classifier = skflow.TensorFlowEstimator(
142+
model_fn=res_net, n_classes=10, batch_size=100, steps=100,
143+
learning_rate=0.001)
139144

140-
classifier.fit(mnist.train.images, mnist.train.labels)
145+
while True:
146+
# Train model and save summaries into logdir.
147+
classifier.fit(mnist.train.images, mnist.train.labels, logdir="models/resnet/")
148+
149+
# Calculate accuracy.
150+
score = metrics.accuracy_score(
151+
mnist.test.labels, classifier.predict(mnist.test.images, batch_size=64))
152+
print('Accuracy: {0:f}'.format(score))
153+
154+
# Save model graph and checkpoints.
155+
classifier.save("models/resnet/")
141156

142-
# Calculate accuracy
143-
score = metrics.accuracy_score(mnist.test.labels, classifier.predict(mnist.test.images))
144-
print('Accuracy: {0:f}'.format(score))

0 commit comments

Comments
 (0)