
Commit dfebf9a

float16 example
1 parent c99a6a2 commit dfebf9a

File tree

1 file changed (+104, -0)


example/tutorial_mnist_float16.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
#! /usr/bin/python
# -*- coding: utf-8 -*-

import time

import numpy as np
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *

X_train, y_train, X_val, y_val, X_test, y_test = \
    tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
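
# Hedged sanity check: the fixed-size placeholders below assume TensorLayer's
# usual MNIST split of 50k/10k/10k float32 images in NHWC layout.
print(X_train.shape, X_train.dtype)  # expected: (50000, 28, 28, 1) float32
print(X_val.shape, X_test.shape)     # expected: (10000, 28, 28, 1) each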

sess = tf.InteractiveSession()

batch_size = 128
D_TYPE = tf.float32  # set to tf.float16 to run the half-precision version

x = tf.placeholder(D_TYPE, shape=[batch_size, 28, 28, 1])
y_ = tf.placeholder(tf.int64, shape=[batch_size])

def model(x, is_train=True, reuse=False):
    with tf.variable_scope("model", reuse=reuse):
        tl.layers.set_name_reuse(reuse)
        n = InputLayer(x, name='input')
        # cnn
        n = Conv2d(n, 32, (5, 5), (1, 1), padding='SAME',
                   W_init_args={'dtype': D_TYPE}, b_init_args={'dtype': D_TYPE}, name='cnn1')
        n = BatchNormLayer(n, act=tf.nn.relu, is_train=is_train, dtype=D_TYPE, name='bn1')
        n = MaxPool2d(n, (2, 2), (2, 2), padding='SAME', name='pool1')
        n = Conv2d(n, 64, (5, 5), (1, 1), padding='SAME',
                   W_init_args={'dtype': D_TYPE}, b_init_args={'dtype': D_TYPE}, name='cnn2')
        n = BatchNormLayer(n, act=tf.nn.relu, is_train=is_train, dtype=D_TYPE, name='bn2')
        n = MaxPool2d(n, (2, 2), (2, 2), padding='SAME', name='pool2')
        # mlp
        n = FlattenLayer(n, name='flatten')
        n = DropoutLayer(n, keep=0.5, is_fix=True, is_train=is_train, name='drop1')
        n = DenseLayer(n, 256, act=tf.nn.relu,
                       W_init_args={'dtype': D_TYPE}, b_init_args={'dtype': D_TYPE}, name='relu1')
        n = DropoutLayer(n, keep=0.5, is_fix=True, is_train=is_train, name='drop2')
        n = DenseLayer(n, 10, act=tf.identity,
                       W_init_args={'dtype': D_TYPE}, b_init_args={'dtype': D_TYPE}, name='output')
    return n

# define inference networks: one graph for training, a reused one for testing
net_train = model(x, is_train=True, reuse=False)
net_test = model(x, is_train=False, reuse=True)

net_train.print_params(False)
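
# Hedged sanity check: confirm the parameters really carry D_TYPE. `all_params`
# is TensorLayer 1.x's list of layer variables; with D_TYPE = tf.float16 each
# line should report a float16 dtype.
for p in net_train.all_params:
    print(p.name, p.dtype)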

# cost for training
y = net_train.outputs
cost = tl.cost.cross_entropy(y, y_, name='xentropy')

# cost and accuracy for evaluation
y2 = net_test.outputs
cost_test = tl.cost.cross_entropy(y2, y_, name='xentropy2')
correct_prediction = tf.equal(tf.argmax(y2, 1), y_)
acc = tf.reduce_mean(tf.cast(correct_prediction, D_TYPE))
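
# One hedged caveat: with D_TYPE = tf.float16 the accuracy above is also
# averaged in half precision. If that matters, the metric alone can stay in
# float32 without touching the model, e.g.:
# acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))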

# define the optimizer
train_params = tl.layers.get_variables_with_name('model', train_only=True, printable=False)
train_op = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.9, beta2=0.999,
                                  # epsilon=1e-08,  # TensorFlow's default, fine for float32
                                  epsilon=1e-4,  # for float16, see https://stackoverflow.com/questions/42064941/tensorflow-float16-support-is-broken
                                  use_locking=False).minimize(cost, var_list=train_params)
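
# Why the larger epsilon: a quick numpy check (an illustration, not from the
# commit) shows Adam's float32 default underflows in half precision, since
# 1e-8 is below float16's smallest positive subnormal (~6e-8), which would
# leave a zero in Adam's denominator.
print(np.float16(1e-8))           # 0.0 -- underflows to zero
print(np.float16(1e-4))           # 0.0001 -- representable, hence epsilon=1e-4
print(np.finfo(np.float16).tiny)  # 6.104e-05, smallest normal float16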

# initialize all variables in the session
tl.layers.initialize_global_variables(sess)

# train the network
n_epoch = 500
print_freq = 1

for epoch in range(n_epoch):
    start_time = time.time()
    for X_train_a, y_train_a in tl.iterate.minibatches(
            X_train, y_train, batch_size, shuffle=True):
        sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a})

    if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
        print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
        train_loss, train_acc, n_batch = 0, 0, 0
        for X_train_a, y_train_a in tl.iterate.minibatches(
                X_train, y_train, batch_size, shuffle=True):
            err, ac = sess.run([cost_test, acc], feed_dict={x: X_train_a, y_: y_train_a})
            train_loss += err
            train_acc += ac
            n_batch += 1
        print("   train loss: %f" % (train_loss / n_batch))
        print("   train acc: %f" % (train_acc / n_batch))
        val_loss, val_acc, n_batch = 0, 0, 0
        for X_val_a, y_val_a in tl.iterate.minibatches(
                X_val, y_val, batch_size, shuffle=True):
            err, ac = sess.run([cost_test, acc], feed_dict={x: X_val_a, y_: y_val_a})
            val_loss += err
            val_acc += ac
            n_batch += 1
        print("   val loss: %f" % (val_loss / n_batch))
        print("   val acc: %f" % (val_acc / n_batch))

print('Evaluation')
test_loss, test_acc, n_batch = 0, 0, 0
for X_test_a, y_test_a in tl.iterate.minibatches(
        X_test, y_test, batch_size, shuffle=True):
    err, ac = sess.run([cost_test, acc], feed_dict={x: X_test_a, y_: y_test_a})
    test_loss += err
    test_acc += ac
    n_batch += 1
print("   test loss: %f" % (test_loss / n_batch))
print("   test acc: %f" % (test_acc / n_batch))
