-# /usr/bin/python
+#! /usr/bin/python
 # -*- coding: utf-8 -*-
 
+import numpy as np
 import tensorflow as tf
+import tensorlayer as tl
 from tensorflow.python.training import moving_averages
-
 from tensorlayer import logging
-from tensorlayer.decorators import deprecated_alias
 from tensorlayer.layers.core import Layer
 from tensorlayer.layers.utils import (quantize_active_overflow, quantize_weight_overflow)
 
@@ -22,8 +22,6 @@ class QuanConv2dWithBN(Layer):
 
     Parameters
     ----------
-    prev_layer : :class:`Layer`
-        Previous layer.
     n_filter : int
         The number of filters.
     filter_size : tuple of int
@@ -51,49 +49,33 @@ class QuanConv2dWithBN(Layer):
         The bits of this layer's parameter
     bitA : int
         The bits of the output of previous layer
-    decay : float
-        A decay factor for `ExponentialMovingAverage`.
-        Suggest to use a large value for large dataset.
-    epsilon : float
-        Eplison.
-    is_train : boolean
-        Is being used for training or inference.
-    beta_init : initializer or None
-        The initializer for initializing beta, if None, skip beta.
-        Usually you should not skip beta unless you know what happened.
-    gamma_init : initializer or None
-        The initializer for initializing gamma, if None, skip gamma.
     use_gemm : boolean
         If True, use gemm instead of ``tf.matmul`` for inference. (TODO).
     W_init : initializer
         The initializer for the weight matrix.
     W_init_args : dictionary
         The arguments for the weight matrix initializer.
-    use_cudnn_on_gpu : bool
-        Default is False.
-    data_format : str
-        "NHWC" or "NCHW", default is "NHWC".
+    data_format : str
+        "channels_last" or "channels_first", default is "channels_last".
+    dilation_rate : tuple of int
+        The dilation rate to use for dilated convolution.
+    in_channels : int
+        The number of input channels.
     name : str
         A unique layer name.
 
     Examples
     ---------
-    >>> import tensorflow as tf
     >>> import tensorlayer as tl
-    >>> x = tf.placeholder(tf.float32, [None, 256, 256, 3])
-    >>> net = tl.layers.InputLayer(x, name='input')
-    >>> net = tl.layers.QuanConv2dWithBN(net, 64, (5, 5), (1, 1), act=tf.nn.relu, padding='SAME', is_train=is_train, bitW=bitW, bitA=bitA, name='qcnnbn1')
-    >>> net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool1')
-    ...
-    >>> net = tl.layers.QuanConv2dWithBN(net, 64, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, is_train=is_train, bitW=bitW, bitA=bitA, name='qcnnbn2')
-    >>> net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool2')
-    ...
+    >>> net = tl.layers.Input([50, 256, 256, 3])
+    >>> layer = tl.layers.QuanConv2dWithBN(n_filter=64, filter_size=(5, 5), strides=(1, 1), padding='SAME', name='qcnnbn1')
+    >>> print(layer)
+    >>> net = tl.layers.QuanConv2dWithBN(n_filter=64, filter_size=(5, 5), strides=(1, 1), padding='SAME', name='qcnnbn2')(net)
+    >>> print(net)
     """
 
-    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
     def __init__(
         self,
-        prev_layer,
         n_filter=32,
         filter_size=(3, 3),
         strides=(1, 1),
@@ -102,125 +84,150 @@ def __init__(
         decay=0.9,
         epsilon=1e-5,
         is_train=False,
-        gamma_init=tf.compat.v1.initializers.ones,
-        beta_init=tf.compat.v1.initializers.zeros,
+        gamma_init=tl.initializers.truncated_normal(stddev=0.02),
+        beta_init=tl.initializers.truncated_normal(stddev=0.02),
         bitW=8,
         bitA=8,
         use_gemm=False,
-        W_init=tf.compat.v1.initializers.truncated_normal(stddev=0.02),
+        W_init=tl.initializers.truncated_normal(stddev=0.02),
         W_init_args=None,
-        use_cudnn_on_gpu=None,
-        data_format=None,
+        data_format="channels_last",
+        dilation_rate=(1, 1),
+        in_channels=None,
         name='quan_cnn2d_bn',
     ):
-        super(QuanConv2dWithBN, self).__init__(prev_layer=prev_layer, act=act, W_init_args=W_init_args, name=name)
-
+        super(QuanConv2dWithBN, self).__init__(act=act, name=name)
+        self.n_filter = n_filter
+        self.filter_size = filter_size
+        self.strides = strides
+        self.padding = padding
+        self.decay = decay
+        self.epsilon = epsilon
+        self.is_train = is_train
+        self.gamma_init = gamma_init
+        self.beta_init = beta_init
+        self.bitW = bitW
+        self.bitA = bitA
+        self.use_gemm = use_gemm
+        self.W_init = W_init
+        self.W_init_args = W_init_args
+        self.data_format = data_format
+        self.dilation_rate = dilation_rate
+        self.in_channels = in_channels
         logging.info(
             "QuanConv2dWithBN %s: n_filter: %d filter_size: %s strides: %s pad: %s act: %s" % (
                 self.name, n_filter, filter_size, str(strides), padding,
                 self.act.__name__ if self.act is not None else 'No Activation'
             )
         )
 
-        x = self.inputs
-        self.inputs = quantize_active_overflow(self.inputs, bitA)  # Do not remove
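+        # If in_channels is known at construction time, build eagerly; otherwise defer to the first forward pass.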
+        if self.in_channels:
+            self.build(None)
+            self._built = True
 
         if use_gemm:
             raise Exception("TODO. The current version use tf.matmul for inferencing.")
 
         if len(strides) != 2:
             raise ValueError("len(strides) should be 2.")
 
-        try:
-            pre_channel = int(prev_layer.outputs.get_shape()[-1])
-        except Exception:  # if pre_channel is ?, it happens when using Spatial Transformer Net
-            pre_channel = 1
-            logging.warning("[warnings] unknow input channels, set to 1")
-
-        shape = (filter_size[0], filter_size[1], pre_channel, n_filter)
-        strides = (1, strides[0], strides[1], 1)
-
-        with tf.compat.v1.variable_scope(name):
-            W = tf.compat.v1.get_variable(
-                name='W_conv2d', shape=shape, initializer=W_init, dtype=LayersConfig.tf_dtype, **self.W_init_args
-            )
-
-            conv = tf.nn.conv2d(
-                x, W, strides=strides, padding=padding, use_cudnn_on_gpu=use_cudnn_on_gpu, data_format=data_format
-            )
-
-            para_bn_shape = conv.get_shape()[-1:]
-
-            if gamma_init:
-                scale_para = tf.compat.v1.get_variable(
-                    name='scale_para', shape=para_bn_shape, initializer=gamma_init, dtype=LayersConfig.tf_dtype,
-                    trainable=is_train
-                )
-            else:
-                scale_para = None
-
-            if beta_init:
-                offset_para = tf.compat.v1.get_variable(
-                    name='offset_para', shape=para_bn_shape, initializer=beta_init, dtype=LayersConfig.tf_dtype,
-                    trainable=is_train
-                )
-            else:
-                offset_para = None
-
-            moving_mean = tf.compat.v1.get_variable(
-                'moving_mean', para_bn_shape, initializer=tf.compat.v1.initializers.constant(1.),
-                dtype=LayersConfig.tf_dtype, trainable=False
+    def __repr__(self):
+        actstr = self.act.__name__ if self.act is not None else 'No Activation'
+        s = (
+            '{classname}(in_channels={in_channels}, out_channels={n_filter}, kernel_size={filter_size}'
+            ', strides={strides}, padding={padding}, ' + actstr
+        )
+        if self.dilation_rate != (1,) * len(self.dilation_rate):
+            s += ', dilation={dilation_rate}'
+        if self.name is not None:
+            s += ', name=\'{name}\''
+        s += ')'
+        return s.format(classname=self.__class__.__name__, **self.__dict__)
+
+    def build(self, inputs_shape):
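+        # Map channels_last/channels_first to NHWC/NCHW and expand strides/dilations to the 4-d lists tf.nn.conv2d expects.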
+        if self.data_format == 'channels_last':
+            self.data_format = 'NHWC'
+            if self.in_channels is None:
+                self.in_channels = inputs_shape[-1]
+            self._strides = [1, self.strides[0], self.strides[1], 1]
+            self._dilation_rate = [1, self.dilation_rate[0], self.dilation_rate[1], 1]
+        elif self.data_format == 'channels_first':
+            self.data_format = 'NCHW'
+            if self.in_channels is None:
+                self.in_channels = inputs_shape[1]
+            self._strides = [1, 1, self.strides[0], self.strides[1]]
+            self._dilation_rate = [1, 1, self.dilation_rate[0], self.dilation_rate[1]]
+        else:
+            raise Exception("data_format should be either channels_last or channels_first")
+
+        self.filter_shape = (self.filter_size[0], self.filter_size[1], self.in_channels, self.n_filter)
+        self.W = self._get_weights("filters", shape=self.filter_shape, init=self.W_init)
+
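+        # BN keeps one scale/offset pair and one moving mean/variance pair per output channel.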
+        para_bn_shape = (self.n_filter,)
+        if self.gamma_init:
+            self.scale_para = self._get_weights(
+                "scale_para", shape=para_bn_shape, init=self.gamma_init, trainable=self.is_train
             )
+        else:
+            self.scale_para = None
 
-            moving_variance = tf.compat.v1.get_variable(
-                'moving_variance',
-                para_bn_shape,
-                initializer=tf.compat.v1.initializers.constant(1.),
-                dtype=LayersConfig.tf_dtype,
-                trainable=False,
+        if self.beta_init:
+            self.offset_para = self._get_weights(
+                "offset_para", shape=para_bn_shape, init=self.beta_init, trainable=self.is_train
             )
+        else:
+            self.offset_para = None
 
-            mean, variance = tf.nn.moments(x=conv, axes=list(range(len(conv.get_shape()) - 1)))
-
-            update_moving_mean = moving_averages.assign_moving_average(
-                moving_mean, mean, decay, zero_debias=False
-            )  # if zero_debias=True, has bias
-
-            update_moving_variance = moving_averages.assign_moving_average(
-                moving_variance, variance, decay, zero_debias=False
-            )  # if zero_debias=True, has bias
+        self.moving_mean = self._get_weights(
+            "moving_mean", shape=para_bn_shape, init=tl.initializers.constant(1.0), trainable=False
+        )
+        self.moving_variance = self._get_weights(
+            "moving_variance", shape=para_bn_shape, init=tl.initializers.constant(1.0), trainable=False
+        )
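+        # The moving statistics are updated by EMA in forward() during training and read directly at inference.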
 
-            def mean_var_with_update():
-                with tf.control_dependencies([update_moving_mean, update_moving_variance]):
-                    return tf.identity(mean), tf.identity(variance)
+    def forward(self, inputs):
+        x = inputs
+        inputs = quantize_active_overflow(inputs, self.bitA)  # Do not remove
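+        # Two-pass scheme: batch moments come from a full-precision conv on the raw input x,
+        # while the folded, quantized conv below runs on the quantized activations.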
+        outputs = tf.nn.conv2d(
+            input=x, filters=self.W, strides=self._strides, padding=self.padding, data_format=self.data_format,
+            dilations=self._dilation_rate, name=self.name
+        )
 
-            if is_train:
-                mean, var = mean_var_with_update()
-            else:
-                mean, var = moving_mean, moving_variance
+        mean, variance = tf.nn.moments(outputs, axes=list(range(len(outputs.get_shape()) - 1)))
 
-            w_fold = _w_fold(W, scale_para, var, epsilon)
-            bias_fold = _bias_fold(offset_para, scale_para, mean, var, epsilon)
+        update_moving_mean = moving_averages.assign_moving_average(
+            self.moving_mean, mean, self.decay, zero_debias=False
+        )  # if zero_debias=True, has bias
+        update_moving_variance = moving_averages.assign_moving_average(
+            self.moving_variance, variance, self.decay, zero_debias=False
+        )  # if zero_debias=True, has bias
 
-            W = quantize_weight_overflow(w_fold, bitW)
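+        # Training uses the batch moments (and triggers the EMA updates); inference uses the stored moving stats.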
+        if self.is_train:
+            mean, var = self.mean_var_with_update(update_moving_mean, update_moving_variance, mean, variance)
+        else:
+            mean, var = self.moving_mean, self.moving_variance
 
-            conv_fold = tf.nn.conv2d(
-                self.inputs, W, strides=strides, padding=padding, use_cudnn_on_gpu=use_cudnn_on_gpu,
-                data_format=data_format
-            )
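+        # BN folding: absorb gamma / sqrt(var + eps) into the kernel so the conv output comes out already normalized.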
+        w_fold = self._w_fold(self.W, self.scale_para, var, self.epsilon)
 
-            self.outputs = tf.nn.bias_add(conv_fold, bias_fold, name='bn_bias_add')
+        W_ = quantize_weight_overflow(w_fold, self.bitW)
 
-            self.outputs = self._apply_activation(self.outputs)
+        conv_fold = tf.nn.conv2d(
+            inputs, W_, strides=self._strides, padding=self.padding, data_format=self.data_format,
+            dilations=self._dilation_rate
+        )
 
-            self._add_layers(self.outputs)
+        if self.beta_init:
+            bias_fold = self._bias_fold(self.offset_para, self.scale_para, mean, var, self.epsilon)
+            conv_fold = tf.nn.bias_add(conv_fold, bias_fold, name='bn_bias_add')
 
-            self._add_params([W, scale_para, offset_para, moving_mean, moving_variance])
+        if self.act:
+            conv_fold = self.act(conv_fold)
 
+        return conv_fold
 
-def _w_fold(w, gama, var, epsilon):
-    return tf.compat.v1.div(tf.multiply(gama, w), tf.sqrt(var + epsilon))
+    def mean_var_with_update(self, update_moving_mean, update_moving_variance, mean, variance):
+        with tf.control_dependencies([update_moving_mean, update_moving_variance]):
+            return tf.identity(mean), tf.identity(variance)
 
+    def _w_fold(self, w, gama, var, epsilon):
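+        # Folded kernel: w_fold = gamma * w / sqrt(var + epsilon)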
+        return tf.compat.v1.div(tf.multiply(gama, w), tf.sqrt(var + epsilon))
 
-def _bias_fold(beta, gama, mean, var, epsilon):
-    return tf.subtract(beta, tf.compat.v1.div(tf.multiply(gama, mean), tf.sqrt(var + epsilon)))
+    def _bias_fold(self, beta, gama, mean, var, epsilon):
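+        # Folded bias: bias_fold = beta - gamma * mean / sqrt(var + epsilon)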
+        return tf.subtract(beta, tf.compat.v1.div(tf.multiply(gama, mean), tf.sqrt(var + epsilon)))