1717 BatchNormLayer , Conv2d , DepthwiseConv2d , FlattenLayer , GlobalMeanPool2d , InputLayer , ReshapeLayer
1818)
1919
20+ MODEL_PATH = os .path .join ("models" , "mobilenet.npz" )
21+
2022
2123def conv_block (n , n_filter , filter_size = (3 , 3 ), strides = (1 , 1 ), is_train = False , name = 'conv_block' ):
2224 # ref: https://github.com/keras-team/keras/blob/master/keras/applications/mobilenet.py
@@ -101,10 +103,10 @@ def mobilenet(x, is_train=True, reuse=False):
101103sess = tf .InteractiveSession ()
102104# tl.layers.initialize_global_variables(sess)
103105
104- if not os .path .isfile ("mobilenet.npz" ):
106+ if not os .path .isfile (MODEL_PATH ):
105107 raise Exception ("Please download mobilenet.npz from : https://github.com/tensorlayer/pretrained-models" )
106108
107- tl .files .load_and_assign_npz (sess = sess , name = 'mobilenet.npz' , network = n )
109+ tl .files .load_and_assign_npz (sess = sess , name = MODEL_PATH , network = n )
108110
109111img = tl .vis .read_image ('data/tiger.jpeg' )
110112img = tl .prepro .imresize (img , (224 , 224 )) / 255
@@ -114,4 +116,4 @@ def mobilenet(x, is_train=True, reuse=False):
114116
115117print (" End time : %.5ss" % (time .time () - start_time ))
116118print ('Predicted :' , decode_predictions ([prob ], top = 3 )[0 ])
117- # tl.files.save_npz(n.all_params, name='mobilenet.npz' , sess=sess)
119+ # tl.files.save_npz(n.all_params, name=MODEL_PATH , sess=sess)
0 commit comments