@@ -3,110 +3,104 @@
 
 import time
 
+import numpy as np
 import tensorflow as tf
 
 import tensorlayer as tl
+from tensorlayer.layers import (BatchNorm, BinaryConv2d, BinaryDense, Flatten, Input, MaxPool2d, Sign)
+from tensorlayer.models import Model
 
-tf.logging.set_verbosity(tf.logging.DEBUG)
 tl.logging.set_verbosity(tl.logging.DEBUG)
 
 X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
-# X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False)
-
-sess = tf.InteractiveSession()
 
 batch_size = 128
 
-x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1])
-y_ = tf.placeholder(tf.int64, shape=[batch_size])
-
 
-def model(x, is_train=True, reuse=False):
+def model(inputs_shape, n_class=10):
     # In BNN, all the layers inputs are binary, with the exception of the first layer.
     # ref: https://github.com/itayhubara/BinaryNet.tf/blob/master/models/BNN_cifar10.py
-    with tf.variable_scope("binarynet", reuse=reuse):
-        net = tl.layers.InputLayer(x, name='input')
-        net = tl.layers.BinaryConv2d(net, 32, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn1')
-        net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool1')
-        net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn1')
-
-        net = tl.layers.SignLayer(net)
-        net = tl.layers.BinaryConv2d(net, 64, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn2')
-        net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool2')
-        net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn2')
-
-        net = tl.layers.FlattenLayer(net)
-        # net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop1')
-        net = tl.layers.SignLayer(net)
-        net = tl.layers.BinaryDenseLayer(net, 256, b_init=None, name='dense')
-        net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn3')
-
-        # net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop2')
-        net = tl.layers.SignLayer(net)
-        net = tl.layers.BinaryDenseLayer(net, 10, b_init=None, name='bout')
-        net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bno')
+    net_in = Input(inputs_shape, name='input')
+    net = BinaryConv2d(32, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn1')(net_in)
+    net = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool1')(net)
+    net = BatchNorm(act=tl.act.htanh, name='bn1')(net)
+
+    net = Sign("sign1")(net)
+    net = BinaryConv2d(64, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn2')(net)
+    net = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool2')(net)
+    net = BatchNorm(act=tl.act.htanh, name='bn2')(net)
+
+    net = Flatten('ft')(net)
+    net = Sign("sign2")(net)
+    net = BinaryDense(256, b_init=None, name='dense')(net)
+    net = BatchNorm(act=tl.act.htanh, name='bn3')(net)
+
+    net = Sign("sign3")(net)
+    net = BinaryDense(10, b_init=None, name='bout')(net)
+    net = BatchNorm(name='bno')(net)
+    net = Model(inputs=net_in, outputs=net, name='binarynet')
     return net
 
 
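The `Sign` layers are what keep every input after the first layer binary, as the comment in `model` notes: between blocks, real-valued activations are collapsed to +/-1 before the next binary convolution or dense layer. A rough illustrative sketch of that forward-pass effect, assuming TensorFlow 2.x eager mode; plain `tf.sign` is only an approximation, since a quantizing layer is also expected to define a surrogate gradient for backprop, which `tf.sign` alone does not provide:

    import tensorflow as tf

    x = tf.constant([[-0.7, 0.2, 1.3]])
    print(tf.sign(x).numpy())  # [[-1.  1.  1.]] -- activations collapse to +/-1 before the next binary layer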
-# define inferences
-net_train = model(x, is_train=True, reuse=False)
-net_test = model(x, is_train=False, reuse=True)
-
-# cost for training
-y = net_train.outputs
-cost = tl.cost.cross_entropy(y, y_, name='xentropy')
+def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None):
+    with tf.GradientTape() as tape:
+        y_pred = network(X_batch)
+        _loss = cost(y_pred, y_batch)
+    grad = tape.gradient(_loss, network.trainable_weights)
+    train_op.apply_gradients(zip(grad, network.trainable_weights))
+    if acc is not None:
+        _acc = acc(y_pred, y_batch)
+        return _loss, _acc
+    else:
+        return _loss, None
 
-# cost and accuracy for evalution
-y2 = net_test.outputs
-cost_test = tl.cost.cross_entropy(y2, y_, name='xentropy2')
-correct_prediction = tf.equal(tf.argmax(y2, 1), y_)
-acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 
-# define the optimizer
-train_params = tl.layers.get_variables_with_name('binarynet', True, True)
-train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_params)
+def accuracy(_logits, y_batch):
+    return np.mean(np.equal(np.argmax(_logits, 1), y_batch))
 
-# initialize all variables in the session
-sess.run(tf.global_variables_initializer())
-
-net_train.print_params()
-net_train.print_layers()
 
 n_epoch = 200
 print_freq = 5
 
-# print(sess.run(net_test.all_params))  # print real values of parameters
+net = model([None, 28, 28, 1])
+train_op = tf.optimizers.Adam(learning_rate=0.0001)
+cost = tl.cost.cross_entropy
 
 for epoch in range(n_epoch):
     start_time = time.time()
+    train_loss, train_acc, n_batch = 0, 0, 0
+    net.train()
+
     for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
-        sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a})
+        _loss, acc = _train_step(net, X_train_a, y_train_a, cost=cost, train_op=train_op, acc=accuracy)
+        train_loss += _loss
+        train_acc += acc
+        n_batch += 1
+
+    # print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
+    # print("   train loss: %f" % (train_loss / n_batch))
+    # print("   train acc: %f" % (train_acc / n_batch))
 
     if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
         print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
-        train_loss, train_acc, n_batch = 0, 0, 0
-        for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
-            err, ac = sess.run([cost_test, acc], feed_dict={x: X_train_a, y_: y_train_a})
-            train_loss += err
-            train_acc += ac
-            n_batch += 1
         print("   train loss: %f" % (train_loss / n_batch))
         print("   train acc: %f" % (train_acc / n_batch))
-        val_loss, val_acc, n_batch = 0, 0, 0
+        val_loss, val_acc, val_batch = 0, 0, 0
+        net.eval()
         for X_val_a, y_val_a in tl.iterate.minibatches(X_val, y_val, batch_size, shuffle=True):
-            err, ac = sess.run([cost_test, acc], feed_dict={x: X_val_a, y_: y_val_a})
-            val_loss += err
-            val_acc += ac
-            n_batch += 1
-        print("   val loss: %f" % (val_loss / n_batch))
-        print("   val acc: %f" % (val_acc / n_batch))
-
-print('Evaluation')
-test_loss, test_acc, n_batch = 0, 0, 0
+            _logits = net(X_val_a)
+            val_loss += tl.cost.cross_entropy(_logits, y_val_a, name='eval_loss')
+            val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_val_a))
+            val_batch += 1
+        print("   val loss: {}".format(val_loss / val_batch))
+        print("   val acc: {}".format(val_acc / val_batch))
+
+net.test()
+test_loss, test_acc, n_test_batch = 0, 0, 0
 for X_test_a, y_test_a in tl.iterate.minibatches(X_test, y_test, batch_size, shuffle=True):
-    err, ac = sess.run([cost_test, acc], feed_dict={x: X_test_a, y_: y_test_a})
-    test_loss += err
-    test_acc += ac
-    n_batch += 1
-print("   test loss: %f" % (test_loss / n_batch))
-print("   test acc: %f" % (test_acc / n_batch))
+    _logits = net(X_test_a)
+    test_loss += tl.cost.cross_entropy(_logits, y_test_a, name='test_loss')
+    test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_test_a))
+    n_test_batch += 1
+print("   test loss: %f" % (test_loss / n_test_batch))
+print("   test acc: %f" % (test_acc / n_test_batch))
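With the loop above finished, the trained `net` can be queried directly in eager mode. A minimal sketch of single-image inference, assuming the script above has run and `net`, `X_test`, and `np` are still in scope (the variable names below are illustrative only):

    net.eval()  # keep BatchNorm on its moving statistics for inference
    _logits = net(X_test[:1].astype('float32'))  # forward pass on one test image
    print("predicted class:", int(np.argmax(_logits[0])))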