77from tensorflow .python .training import moving_averages
88from tensorlayer import logging
99from tensorlayer .layers .core import Layer
10- from tensorlayer .layers .utils import (quantize_active_overflow ,
11- quantize_weight_overflow )
10+ from tensorlayer .layers .utils import (quantize_active_overflow , quantize_weight_overflow )
1211
1312# from tensorlayer.layers.core import LayersConfig
1413
@@ -76,26 +75,26 @@ class QuanConv2dWithBN(Layer):
7675 """
7776
7877 def __init__ (
79- self ,
80- n_filter = 32 ,
81- filter_size = (3 , 3 ),
82- strides = (1 , 1 ),
83- padding = 'SAME' ,
84- act = None ,
85- decay = 0.9 ,
86- epsilon = 1e-5 ,
87- is_train = False ,
88- gamma_init = tl .initializers .truncated_normal (stddev = 0.02 ),
89- beta_init = tl .initializers .truncated_normal (stddev = 0.02 ),
90- bitW = 8 ,
91- bitA = 8 ,
92- use_gemm = False ,
93- W_init = tl .initializers .truncated_normal (stddev = 0.02 ),
94- W_init_args = None ,
95- data_format = "channels_last" ,
96- dilation_rate = (1 , 1 ),
97- in_channels = None ,
98- name = 'quan_cnn2d_bn' ,
78+ self ,
79+ n_filter = 32 ,
80+ filter_size = (3 , 3 ),
81+ strides = (1 , 1 ),
82+ padding = 'SAME' ,
83+ act = None ,
84+ decay = 0.9 ,
85+ epsilon = 1e-5 ,
86+ is_train = False ,
87+ gamma_init = tl .initializers .truncated_normal (stddev = 0.02 ),
88+ beta_init = tl .initializers .truncated_normal (stddev = 0.02 ),
89+ bitW = 8 ,
90+ bitA = 8 ,
91+ use_gemm = False ,
92+ W_init = tl .initializers .truncated_normal (stddev = 0.02 ),
93+ W_init_args = None ,
94+ data_format = "channels_last" ,
95+ dilation_rate = (1 , 1 ),
96+ in_channels = None ,
97+ name = 'quan_cnn2d_bn' ,
9998 ):
10099 super (QuanConv2dWithBN , self ).__init__ (act = act , name = name )
101100 self .n_filter = n_filter
@@ -160,22 +159,18 @@ def build(self, inputs_shape):
160159 self .filter_shape = (self .filter_size [0 ], self .filter_size [1 ], self .in_channels , self .n_filter )
161160 self .W = self ._get_weights ("filters" , shape = self .filter_shape , init = self .W_init )
162161
163- para_bn_shape = (self .n_filter ,)
162+ para_bn_shape = (self .n_filter , )
164163 if self .gamma_init :
165164 self .scale_para = self ._get_weights (
166- "scale_para" ,
167- shape = para_bn_shape ,
168- init = self .gamma_init ,
169- trainable = self .is_train )
165+ "scale_para" , shape = para_bn_shape , init = self .gamma_init , trainable = self .is_train
166+ )
170167 else :
171168 self .scale_para = None
172169
173170 if self .beta_init :
174171 self .offset_para = self ._get_weights (
175- "offset_para" ,
176- shape = para_bn_shape ,
177- init = self .beta_init ,
178- trainable = self .is_train )
172+ "offset_para" , shape = para_bn_shape , init = self .beta_init , trainable = self .is_train
173+ )
179174 else :
180175 self .offset_para = None
181176
@@ -190,21 +185,18 @@ def forward(self, inputs):
190185 x = inputs
191186 inputs = quantize_active_overflow (inputs , self .bitA ) # Do not remove
192187 outputs = tf .nn .conv2d (
193- input = x ,
194- filters = self .W ,
195- strides = self ._strides ,
196- padding = self .padding ,
197- data_format = self .data_format ,
198- dilations = self ._dilation_rate ,
199- name = self .name
188+ input = x , filters = self .W , strides = self ._strides , padding = self .padding , data_format = self .data_format ,
189+ dilations = self ._dilation_rate , name = self .name
200190 )
201191
202192 mean , variance = tf .nn .moments (outputs , axes = list (range (len (outputs .get_shape ()) - 1 )))
203193
204194 update_moving_mean = moving_averages .assign_moving_average (
205- self .moving_mean , mean , self .decay , zero_debias = False ) # if zero_debias=True, has bias
195+ self .moving_mean , mean , self .decay , zero_debias = False
196+ ) # if zero_debias=True, has bias
206197 update_moving_variance = moving_averages .assign_moving_average (
207- self .moving_variance , mean , self .decay , zero_debias = False ) # if zero_debias=True, has bias
198+ self .moving_variance , mean , self .decay , zero_debias = False
199+ ) # if zero_debias=True, has bias
208200
209201 if self .is_train :
210202 mean , var = self .mean_var_with_update (update_moving_mean , update_moving_variance , mean , variance )
@@ -215,9 +207,7 @@ def forward(self, inputs):
215207
216208 W_ = quantize_weight_overflow (w_fold , self .bitW )
217209
218- conv_fold = tf .nn .conv2d (
219- inputs , W_ , strides = self .strides , padding = self .padding , data_format = self .data_format
220- )
210+ conv_fold = tf .nn .conv2d (inputs , W_ , strides = self .strides , padding = self .padding , data_format = self .data_format )
221211
222212 if self .beta_init :
223213 bias_fold = self ._bias_fold (self .offset_para , self .scale_para , mean , var , self .epsilon )
0 commit comments