|
| 1 | +#! /usr/bin/python |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | + |
| 4 | +import time, os, json |
| 5 | +import numpy as np |
| 6 | +import tensorflow as tf |
| 7 | +import tensorlayer as tl |
| 8 | +from tensorlayer.layers import InputLayer, Conv2d, MaxPool2d, ConcatLayer, DropoutLayer, GlobalMeanPool2d |
| 9 | + |
| 10 | + |
| 11 | +def decode_predictions(preds, top=5): # keras.applications.resnet50 |
| 12 | + fpath = os.path.join("data", "imagenet_class_index.json") |
| 13 | + if tl.files.file_exists(fpath) is False: |
| 14 | + raise Exception("{} / download imagenet_class_index.json from: https://github.com/zsdonghao/tensorlayer/tree/master/example/data") |
| 15 | + if isinstance(preds, np.ndarray) is False: |
| 16 | + preds = np.asarray(preds) |
| 17 | + if len(preds.shape) != 2 or preds.shape[1] != 1000: |
| 18 | + raise ValueError('`decode_predictions` expects ' |
| 19 | + 'a batch of predictions ' |
| 20 | + '(i.e. a 2D array of shape (samples, 1000)). ' |
| 21 | + 'Found array with shape: ' + str(preds.shape)) |
| 22 | + with open(fpath) as f: |
| 23 | + CLASS_INDEX = json.load(f) |
| 24 | + results = [] |
| 25 | + for pred in preds: |
| 26 | + top_indices = pred.argsort()[-top:][::-1] |
| 27 | + result = [tuple(CLASS_INDEX[str(i)]) + (pred[i], ) for i in top_indices] |
| 28 | + result.sort(key=lambda x: x[2], reverse=True) |
| 29 | + results.append(result) |
| 30 | + return results |
| 31 | + |
| 32 | + |
| 33 | +def squeezenet(x, is_train=True, reuse=False): |
| 34 | + # model from: https://github.com/wohlert/keras-squeezenet |
| 35 | + # https://github.com/DT42/squeezenet_demo/blob/master/model.py |
| 36 | + with tf.variable_scope("squeezenet", reuse=reuse): |
| 37 | + with tf.variable_scope("input"): |
| 38 | + n = InputLayer(x) |
| 39 | + # n = Conv2d(n, 96, (7,7),(2,2),tf.nn.relu,'SAME',name='conv1') |
| 40 | + n = Conv2d(n, 64, (3, 3), (2, 2), tf.nn.relu, 'SAME', name='conv1') |
| 41 | + n = MaxPool2d(n, (3, 3), (2, 2), 'VALID', name='max') |
| 42 | + |
| 43 | + with tf.variable_scope("fire2"): |
| 44 | + n = Conv2d(n, 16, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1') |
| 45 | + n1 = Conv2d(n, 64, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1') |
| 46 | + n2 = Conv2d(n, 64, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3') |
| 47 | + n = ConcatLayer([n1, n2], -1, name='concat') |
| 48 | + |
| 49 | + with tf.variable_scope("fire3"): |
| 50 | + n = Conv2d(n, 16, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1') |
| 51 | + n1 = Conv2d(n, 64, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1') |
| 52 | + n2 = Conv2d(n, 64, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3') |
| 53 | + n = ConcatLayer([n1, n2], -1, name='concat') |
| 54 | + n = MaxPool2d(n, (3, 3), (2, 2), 'VALID', name='max') |
| 55 | + |
| 56 | + with tf.variable_scope("fire4"): |
| 57 | + n = Conv2d(n, 32, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1') |
| 58 | + n1 = Conv2d(n, 128, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1') |
| 59 | + n2 = Conv2d(n, 128, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3') |
| 60 | + n = ConcatLayer([n1, n2], -1, name='concat') |
| 61 | + |
| 62 | + with tf.variable_scope("fire5"): |
| 63 | + n = Conv2d(n, 32, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1') |
| 64 | + n1 = Conv2d(n, 128, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1') |
| 65 | + n2 = Conv2d(n, 128, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3') |
| 66 | + n = ConcatLayer([n1, n2], -1, name='concat') |
| 67 | + n = MaxPool2d(n, (3, 3), (2, 2), 'VALID', name='max') |
| 68 | + |
| 69 | + with tf.variable_scope("fire6"): |
| 70 | + n = Conv2d(n, 48, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1') |
| 71 | + n1 = Conv2d(n, 192, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1') |
| 72 | + n2 = Conv2d(n, 192, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3') |
| 73 | + n = ConcatLayer([n1, n2], -1, name='concat') |
| 74 | + |
| 75 | + with tf.variable_scope("fire7"): |
| 76 | + n = Conv2d(n, 48, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1') |
| 77 | + n1 = Conv2d(n, 192, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1') |
| 78 | + n2 = Conv2d(n, 192, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3') |
| 79 | + n = ConcatLayer([n1, n2], -1, name='concat') |
| 80 | + |
| 81 | + with tf.variable_scope("fire8"): |
| 82 | + n = Conv2d(n, 64, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1') |
| 83 | + n1 = Conv2d(n, 256, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1') |
| 84 | + n2 = Conv2d(n, 256, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3') |
| 85 | + n = ConcatLayer([n1, n2], -1, name='concat') |
| 86 | + |
| 87 | + with tf.variable_scope("fire9"): |
| 88 | + n = Conv2d(n, 64, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='squeeze1x1') |
| 89 | + n1 = Conv2d(n, 256, (1, 1), (1, 1), tf.nn.relu, 'SAME', name='expand1x1') |
| 90 | + n2 = Conv2d(n, 256, (3, 3), (1, 1), tf.nn.relu, 'SAME', name='expand3x3') |
| 91 | + n = ConcatLayer([n1, n2], -1, name='concat') |
| 92 | + |
| 93 | + with tf.variable_scope("output"): |
| 94 | + n = DropoutLayer(n, keep=0.5, is_fix=True, is_train=is_train, name='drop1') |
| 95 | + n = Conv2d(n, 1000, (1, 1), (1, 1), padding='VALID', name='conv10') # 13, 13, 1000 |
| 96 | + n = GlobalMeanPool2d(n) |
| 97 | + return n |
| 98 | + |
| 99 | + |
| 100 | +x = tf.placeholder(tf.float32, (None, 224, 224, 3)) |
| 101 | +n = squeezenet(x, False, False) |
| 102 | +softmax = tf.nn.softmax(n.outputs) |
| 103 | +n.print_layers() |
| 104 | +n.print_params(False) |
| 105 | + |
| 106 | +sess = tf.InteractiveSession() |
| 107 | +tl.layers.initialize_global_variables(sess) |
| 108 | + |
| 109 | +if tl.files.file_exists('squeezenet.npz'): |
| 110 | + tl.files.load_and_assign_npz(sess=sess, name='squeezenet.npz', network=n) |
| 111 | +else: |
| 112 | + raise Exception("please download the pre-trained squeezenet.npz from https://github.com/tensorlayer/pretrained-models") |
| 113 | + |
| 114 | +img = tl.vis.read_image('data/tiger.jpeg', '') |
| 115 | +img = tl.prepro.imresize(img, (224, 224)) |
| 116 | +prob = sess.run(softmax, feed_dict={x: [img]})[0] |
| 117 | +start_time = time.time() |
| 118 | +prob = sess.run(softmax, feed_dict={x: [img]})[0] |
| 119 | +print(" End time : %.5ss" % (time.time() - start_time)) |
| 120 | + |
| 121 | +print('Predicted:', decode_predictions([prob], top=3)[0]) |
| 122 | +tl.files.save_npz(n.all_params, name='squeezenet.npz', sess=sess) |
0 commit comments