|
17 | 17 | Reference Paper: http://arxiv.org/pdf/1512.03385.pdf |
18 | 18 | """ |
19 | 19 |
|
| 20 | +import os |
20 | 21 | import random |
21 | 22 | from sklearn import metrics |
22 | 23 |
|
@@ -132,13 +133,24 @@ def res_net(x, y, activation=tf.nn.relu): |
132 | 133 | # Download and load MNIST data. |
133 | 134 | mnist = input_data.read_data_sets('MNIST_data') |
134 | 135 |
|
135 | | -# Train a resnet classifier |
136 | | -classifier = skflow.TensorFlowEstimator( |
137 | | - model_fn=res_net, n_classes=10, batch_size=100, steps=20000, |
138 | | - learning_rate=0.001) |
| 136 | +# Restore model if graph is saved into a folder. |
| 137 | +if os.path.exists("models/resnet/graph.pbtxt"): |
| 138 | + classifier = skflow.TensorFlowEstimator.restore("models/resnet/") |
| 139 | +else: |
| 140 | + # Create a new resnet classifier. |
| 141 | + classifier = skflow.TensorFlowEstimator( |
| 142 | + model_fn=res_net, n_classes=10, batch_size=100, steps=100, |
| 143 | + learning_rate=0.001) |
139 | 144 |
|
140 | | -classifier.fit(mnist.train.images, mnist.train.labels) |
| 145 | +while True: |
| 146 | + # Train model and save summaries into logdir. |
| 147 | + classifier.fit(mnist.train.images, mnist.train.labels, logdir="models/resnet/") |
| 148 | + |
| 149 | + # Calculate accuracy. |
| 150 | + score = metrics.accuracy_score( |
| 151 | + mnist.test.labels, classifier.predict(mnist.test.images, batch_size=64)) |
| 152 | + print('Accuracy: {0:f}'.format(score)) |
| 153 | + |
| 154 | + # Save model graph and checkpoints. |
| 155 | + classifier.save("models/resnet/") |
141 | 156 |
|
142 | | -# Calculate accuracy |
143 | | -score = metrics.accuracy_score(mnist.test.labels, classifier.predict(mnist.test.images)) |
144 | | -print('Accuracy: {0:f}'.format(score)) |
|
0 commit comments