|
7 | 7 |
|
8 | 8 | X_train, y_train, X_val, y_val, X_test, y_test = \ |
9 | 9 | tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1)) |
| 10 | +# X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False) |
10 | 11 |
|
11 | 12 | sess = tf.InteractiveSession() |
12 | 13 |
|
|
17 | 18 |
|
18 | 19 |
|
19 | 20 | def model(x, is_train=True, reuse=False): |
| 21 | + # In BNN, all the layers' inputs are binary, with the exception of the first layer. |
| 22 | + # ref: https://github.com/itayhubara/BinaryNet.tf/blob/master/models/BNN_cifar10.py |
20 | 23 | with tf.variable_scope("binarynet", reuse=reuse): |
21 | 24 | net = tl.layers.InputLayer(x, name='input') |
22 | 25 | net = tl.layers.BinaryConv2d(net, 32, (5, 5), (1, 1), padding='SAME', name='bcnn1') |
23 | 26 | net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool1') |
| 27 | + net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn1') |
24 | 28 |
|
25 | | - net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bn') |
26 | | - net = tl.layers.SignLayer(net, name='sign2') |
| 29 | + net = tl.layers.SignLayer(net) |
27 | 30 | net = tl.layers.BinaryConv2d(net, 64, (5, 5), (1, 1), padding='SAME', name='bcnn2') |
28 | 31 | net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool2') |
| 32 | + net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn2') |
29 | 33 |
|
30 | | - net = tl.layers.SignLayer(net, name='sign2') |
31 | 34 | net = tl.layers.FlattenLayer(net, name='flatten') |
32 | | - net = tl.layers.DropoutLayer(net, 0.5, True, is_train, name='drop1') |
33 | | - # net = tl.layers.DenseLayer(net, 256, act=tf.nn.relu, name='dense') |
| 35 | + net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop1') |
| 36 | + net = tl.layers.SignLayer(net) |
34 | 37 | net = tl.layers.BinaryDenseLayer(net, 256, name='dense') |
35 | | - net = tl.layers.DropoutLayer(net, 0.5, True, is_train, name='drop2') |
36 | | - # net = tl.layers.DenseLayer(net, 10, act=tf.identity, name='output') |
| 38 | + net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn3') |
| 39 | + |
| 40 | + net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop2') |
| 41 | + net = tl.layers.SignLayer(net) |
37 | 42 | net = tl.layers.BinaryDenseLayer(net, 10, name='bout') |
38 | | - # net = tl.layers.ScaleLayer(net, name='scale') |
| 43 | + net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bno') |
39 | 44 | return net |
40 | 45 |
|
41 | 46 |
|
@@ -66,7 +71,7 @@ def model(x, is_train=True, reuse=False): |
66 | 71 | n_epoch = 200 |
67 | 72 | print_freq = 5 |
68 | 73 |
|
69 | | -# print(sess.run(net_test.all_params)) # print real value of parameters |
| 74 | +# print(sess.run(net_test.all_params)) # print real values of parameters |
70 | 75 |
|
71 | 76 | for epoch in range(n_epoch): |
72 | 77 | start_time = time.time() |
|
0 commit comments