|
| 1 | +#! /usr/bin/python |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | + |
| 4 | + |
| 5 | + |
| 6 | +import tensorflow as tf |
| 7 | +import tensorlayer as tl |
| 8 | + |
| 9 | +sess = tf.InteractiveSession() |
| 10 | + |
| 11 | +# prepare data |
| 12 | +X_train, y_train, X_val, y_val, X_test, y_test = \ |
| 13 | + tl.files.load_mnist_dataset(shape=(-1,784)) |
| 14 | +# define placeholder |
| 15 | +x = tf.placeholder(tf.float32, shape=[None, 784], name='x') |
| 16 | +y_ = tf.placeholder(tf.int64, shape=[None, ], name='y_') |
| 17 | + |
| 18 | +# define the network |
| 19 | +network = tl.layers.InputLayer(x, name='input') |
| 20 | +network = tl.layers.DropoutLayer(network, keep=0.8, name='drop1') |
| 21 | +network = tl.layers.DenseLayer(network, 800, tf.nn.relu, name='relu1') |
| 22 | +network = tl.layers.DropoutLayer(network, keep=0.5, name='drop2') |
| 23 | +network = tl.layers.DenseLayer(network, 800, tf.nn.relu, name='relu2') |
| 24 | +network = tl.layers.DropoutLayer(network, keep=0.5, name='drop3') |
| 25 | +# the softmax is implemented internally in tl.cost.cross_entropy(y, y_) to |
| 26 | +# speed up computation, so we use identity here. |
| 27 | +# see tf.nn.sparse_softmax_cross_entropy_with_logits() |
| 28 | +network = tl.layers.DenseLayer(network, n_units=10, |
| 29 | + act=tf.identity, name='output') |
| 30 | + |
| 31 | +# define cost function and metric. |
| 32 | +y = network.outputs |
| 33 | +cost = tl.cost.cross_entropy(y, y_, name='cost') |
| 34 | +correct_prediction = tf.equal(tf.argmax(y, 1), y_) |
| 35 | +acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) |
| 36 | +y_op = tf.argmax(tf.nn.softmax(y), 1) |
| 37 | + |
| 38 | +# define the optimizer |
| 39 | +train_params = network.all_params |
| 40 | +train_op = tf.train.AdamOptimizer(learning_rate=0.0001 |
| 41 | + ).minimize(cost, var_list=train_params) |
| 42 | + |
| 43 | +# initialize all variables in the session |
| 44 | +tl.layers.initialize_global_variables(sess) |
| 45 | + |
| 46 | +# print network information |
| 47 | +network.print_params() |
| 48 | +network.print_layers() |
| 49 | + |
| 50 | +# train the network |
| 51 | +tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_, |
| 52 | + acc=acc, batch_size=500, n_epoch=5, print_freq=5, |
| 53 | + X_val=X_val, y_val=y_val, eval_train=False) |
| 54 | + |
| 55 | +# evaluation |
| 56 | +tl.utils.test(sess, network, acc, X_test, y_test, x, y_, batch_size=None, cost=cost) |
| 57 | + |
| 58 | +# save the network to .npz file |
| 59 | +tl.files.save_npz(network.all_params , name='model.npz') |
| 60 | +sess.close() |
0 commit comments