Skip to content

Commit da15d54

Browse files
zsdonghaowagamamaz
authored andcommitted
squeezenet example (#428)
* update docs * release squeezenet example * release binarynet and squeezenet example * fix typo * fix typo
1 parent b621d40 commit da15d54

File tree

5 files changed

+129
-4
lines changed

5 files changed

+129
-4
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ Examples can be found [in this folder](https://github.com/zsdonghao/tensorlayer/
8484
- VGG 16 (ImageNet). Classification task, see [tutorial_vgg16.py](https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_vgg16.py).
8585
- VGG 19 (ImageNet). Classification task, see [tutorial_vgg19.py](https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_vgg19.py).
8686
- InceptionV3 (ImageNet). Classification task, see [tutorial\_inceptionV3_tfslim.py](https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_inceptionV3_tfslim.py).
87+
- SqueezeNet (ImageNet). Classification task, see [tutorial_squeezenet.py](https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_squeezenet.py)
88+
- BinaryNet (MNIST). Classification task, see [tutorial_squeezenet.py](https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_binarynet_mnist_cnn.py)
8789
- Wide ResNet (CIFAR) by [ritchieng](https://github.com/ritchieng/wideresnet-tensorlayer).
8890
- More CNN implementations of [TF-Slim](https://github.com/tensorflow/models/tree/master/research/slim) can be connected to TensorLayer via SlimNetsLayer.
8991
- [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025) by [zsdonghao](https://github.com/zsdonghao/Spatial-Transformer-Nets).

docs/modules/layers.rst

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -805,11 +805,9 @@ Binary Nets
805805
Read Me
806806
^^^^^^^^^^^^^^
807807

808-
This is an experimental API package for building Binary Nets.
809-
We are using matrix multiplication rather than add-minus and bit-count operation at the moment.
810-
Therefore, these APIs would not speed up the inferencing, for production, you can train model via TensorLayer and deploy the model into other customized C/C++ implementation (We probably provide users an extra C/C++ binary net framework that can load model from TensorLayer).
808+
This is an experimental API package for building Binary Nets. We are using matrix multiplication rather than add-minus and bit-count operation at the moment. Therefore, these APIs would not speed up the inferencing, for production, you can train model via TensorLayer and deploy the model into other customized C/C++ implementation (We probably provide users an extra C/C++ binary net framework that can load model from TensorLayer).
811809

812-
Note that, these experimental APIs can be changed in anytime.
810+
Note that, these experimental APIs can be changed in the future
813811

814812
Binarized Dense
815813
^^^^^^^^^^^^^^^^^

docs/user/example.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ Computer Vision
2121
- VGG 16 (ImageNet). Classification task, see `tutorial_vgg16.py <https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_vgg16.py>`_.
2222
- VGG 19 (ImageNet). Classification task, see `tutorial_vgg19.py <https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_vgg19.py>`_.
2323
- InceptionV3 (ImageNet). Classification task, see `tutorial_inceptionV3_tfslim.py <https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_inceptionV3_tfslim.py>`_.
24+
- SqueezeNet (ImageNet). Classification task, see `tutorial_squeezenet.py <https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_squeezenet.py>`_.
25+
- BinaryNet (MNIST). Classification task, see `tutorial_squeezenet.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_binarynet_mnist_cnn.py>`_.
2426
- Wide ResNet (CIFAR) by `ritchieng <https://github.com/ritchieng/wideresnet-tensorlayer>`_.
2527
- More CNN implementations of `TF-Slim <https://github.com/tensorflow/models/tree/master/research/slim>`_ can be connected to TensorLayer via SlimNetsLayer.
2628
- `Spatial Transformer Networks <https://arxiv.org/abs/1506.02025>`_ by `zsdonghao <https://github.com/zsdonghao/Spatial-Transformer-Nets>`__.

example/data/imagenet_class_index.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

example/tutorial_squeezenet.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
import time, os, json
5+
import numpy as np
6+
import tensorflow as tf
7+
import tensorlayer as tl
8+
from tensorlayer.layers import InputLayer, Conv2d, MaxPool2d, ConcatLayer, DropoutLayer, GlobalMeanPool2d
9+
10+
11+
def decode_predictions(preds, top=5): # keras.applications.resnet50
12+
fpath = os.path.join("data", "imagenet_class_index.json")
13+
if tl.files.file_exists(fpath) is False:
14+
raise Exception("{} / download imagenet_class_index.json from: https://github.com/zsdonghao/tensorlayer/tree/master/example/data")
15+
if isinstance(preds, np.ndarray) is False:
16+
preds = np.asarray(preds)
17+
if len(preds.shape) != 2 or preds.shape[1] != 1000:
18+
raise ValueError('`decode_predictions` expects '
19+
'a batch of predictions '
20+
'(i.e. a 2D array of shape (samples, 1000)). '
21+
'Found array with shape: ' + str(preds.shape))
22+
with open(fpath) as f:
23+
CLASS_INDEX = json.load(f)
24+
results = []
25+
for pred in preds:
26+
top_indices = pred.argsort()[-top:][::-1]
27+
result = [tuple(CLASS_INDEX[str(i)]) + (pred[i], ) for i in top_indices]
28+
result.sort(key=lambda x: x[2], reverse=True)
29+
results.append(result)
30+
return results
31+
32+
33+
def squeezenet(x, is_train=True, reuse=False):
34+
# model from: https://github.com/wohlert/keras-squeezenet
35+
# https://github.com/DT42/squeezenet_demo/blob/master/model.py
36+
with tf.variable_scope("squeezenet", reuse=reuse):
37+
with tf.variable_scope("input"):
38+
n = InputLayer(x)
39+
# n = Conv2d(n, 96, (7,7),(2,2),tf.nn.relu,'SAME',name='conv1')
40+
n = Conv2d(n, 64, (3, 3), (2, 2), tf.nn.relu, 'SAME', name='conv1')
41+
n = MaxPool2d(n, (3, 3), (2, 2), 'VALID', name='max')
42+
43+
with tf.variable_scope("fire2"):
44+
n = Conv2d(n, 16, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1')
45+
n1 = Conv2d(n, 64, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1')
46+
n2 = Conv2d(n, 64, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3')
47+
n = ConcatLayer([n1, n2], -1, name='concat')
48+
49+
with tf.variable_scope("fire3"):
50+
n = Conv2d(n, 16, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1')
51+
n1 = Conv2d(n, 64, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1')
52+
n2 = Conv2d(n, 64, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3')
53+
n = ConcatLayer([n1, n2], -1, name='concat')
54+
n = MaxPool2d(n, (3, 3), (2, 2), 'VALID', name='max')
55+
56+
with tf.variable_scope("fire4"):
57+
n = Conv2d(n, 32, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1')
58+
n1 = Conv2d(n, 128, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1')
59+
n2 = Conv2d(n, 128, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3')
60+
n = ConcatLayer([n1, n2], -1, name='concat')
61+
62+
with tf.variable_scope("fire5"):
63+
n = Conv2d(n, 32, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1')
64+
n1 = Conv2d(n, 128, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1')
65+
n2 = Conv2d(n, 128, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3')
66+
n = ConcatLayer([n1, n2], -1, name='concat')
67+
n = MaxPool2d(n, (3, 3), (2, 2), 'VALID', name='max')
68+
69+
with tf.variable_scope("fire6"):
70+
n = Conv2d(n, 48, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1')
71+
n1 = Conv2d(n, 192, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1')
72+
n2 = Conv2d(n, 192, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3')
73+
n = ConcatLayer([n1, n2], -1, name='concat')
74+
75+
with tf.variable_scope("fire7"):
76+
n = Conv2d(n, 48, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1')
77+
n1 = Conv2d(n, 192, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1')
78+
n2 = Conv2d(n, 192, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3')
79+
n = ConcatLayer([n1, n2], -1, name='concat')
80+
81+
with tf.variable_scope("fire8"):
82+
n = Conv2d(n, 64, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1')
83+
n1 = Conv2d(n, 256, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1')
84+
n2 = Conv2d(n, 256, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3')
85+
n = ConcatLayer([n1, n2], -1, name='concat')
86+
87+
with tf.variable_scope("fire9"):
88+
n = Conv2d(n, 64, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1')
89+
n1 = Conv2d(n, 256, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1')
90+
n2 = Conv2d(n, 256, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3')
91+
n = ConcatLayer([n1, n2], -1, name='concat')
92+
93+
with tf.variable_scope("output"):
94+
n = DropoutLayer(n, keep=0.5, is_fix=True, is_train=is_train, name='drop1')
95+
n = Conv2d(n, 1000, (1, 1), (1, 1), padding='VALID', name='conv10') # 13, 13, 1000
96+
n = GlobalMeanPool2d(n)
97+
return n
98+
99+
100+
x = tf.placeholder(tf.float32, (None, 224, 224, 3))
101+
n = squeezenet(x, False, False)
102+
softmax = tf.nn.softmax(n.outputs)
103+
n.print_layers()
104+
n.print_params(False)
105+
106+
sess = tf.InteractiveSession()
107+
tl.layers.initialize_global_variables(sess)
108+
109+
if tl.files.file_exists('squeezenet.npz'):
110+
tl.files.load_and_assign_npz(sess=sess, name='squeezenet.npz', network=n)
111+
else:
112+
raise Exception("please download the pre-trained squeezenet.npz from https://github.com/tensorlayer/pretrained-models")
113+
114+
img = tl.vis.read_image('data/tiger.jpeg', '')
115+
img = tl.prepro.imresize(img, (224, 224))
116+
prob = sess.run(softmax, feed_dict={x: [img]})[0]
117+
start_time = time.time()
118+
prob = sess.run(softmax, feed_dict={x: [img]})[0]
119+
print(" End time : %.5ss" % (time.time() - start_time))
120+
121+
print('Predicted:', decode_predictions([prob], top=3)[0])
122+
tl.files.save_npz(n.all_params, name='squeezenet.npz', sess=sess)

0 commit comments

Comments
 (0)