Skip to content

Commit 3b388ac

Browse files
committed
use yapf
1 parent 5bd31ae commit 3b388ac

File tree

2 files changed

+53
-68
lines changed

2 files changed

+53
-68
lines changed

tensorlayer/layers/convolution/quan_conv_bn.py

Lines changed: 33 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
from tensorflow.python.training import moving_averages
88
from tensorlayer import logging
99
from tensorlayer.layers.core import Layer
10-
from tensorlayer.layers.utils import (quantize_active_overflow,
11-
quantize_weight_overflow)
10+
from tensorlayer.layers.utils import (quantize_active_overflow, quantize_weight_overflow)
1211

1312
# from tensorlayer.layers.core import LayersConfig
1413

@@ -76,26 +75,26 @@ class QuanConv2dWithBN(Layer):
7675
"""
7776

7877
def __init__(
79-
self,
80-
n_filter=32,
81-
filter_size=(3, 3),
82-
strides=(1, 1),
83-
padding='SAME',
84-
act=None,
85-
decay=0.9,
86-
epsilon=1e-5,
87-
is_train=False,
88-
gamma_init=tl.initializers.truncated_normal(stddev=0.02),
89-
beta_init=tl.initializers.truncated_normal(stddev=0.02),
90-
bitW=8,
91-
bitA=8,
92-
use_gemm=False,
93-
W_init=tl.initializers.truncated_normal(stddev=0.02),
94-
W_init_args=None,
95-
data_format="channels_last",
96-
dilation_rate=(1, 1),
97-
in_channels=None,
98-
name='quan_cnn2d_bn',
78+
self,
79+
n_filter=32,
80+
filter_size=(3, 3),
81+
strides=(1, 1),
82+
padding='SAME',
83+
act=None,
84+
decay=0.9,
85+
epsilon=1e-5,
86+
is_train=False,
87+
gamma_init=tl.initializers.truncated_normal(stddev=0.02),
88+
beta_init=tl.initializers.truncated_normal(stddev=0.02),
89+
bitW=8,
90+
bitA=8,
91+
use_gemm=False,
92+
W_init=tl.initializers.truncated_normal(stddev=0.02),
93+
W_init_args=None,
94+
data_format="channels_last",
95+
dilation_rate=(1, 1),
96+
in_channels=None,
97+
name='quan_cnn2d_bn',
9998
):
10099
super(QuanConv2dWithBN, self).__init__(act=act, name=name)
101100
self.n_filter = n_filter
@@ -160,22 +159,18 @@ def build(self, inputs_shape):
160159
self.filter_shape = (self.filter_size[0], self.filter_size[1], self.in_channels, self.n_filter)
161160
self.W = self._get_weights("filters", shape=self.filter_shape, init=self.W_init)
162161

163-
para_bn_shape = (self.n_filter,)
162+
para_bn_shape = (self.n_filter, )
164163
if self.gamma_init:
165164
self.scale_para = self._get_weights(
166-
"scale_para",
167-
shape=para_bn_shape,
168-
init=self.gamma_init,
169-
trainable=self.is_train)
165+
"scale_para", shape=para_bn_shape, init=self.gamma_init, trainable=self.is_train
166+
)
170167
else:
171168
self.scale_para = None
172169

173170
if self.beta_init:
174171
self.offset_para = self._get_weights(
175-
"offset_para",
176-
shape=para_bn_shape,
177-
init=self.beta_init,
178-
trainable=self.is_train)
172+
"offset_para", shape=para_bn_shape, init=self.beta_init, trainable=self.is_train
173+
)
179174
else:
180175
self.offset_para = None
181176

@@ -190,21 +185,18 @@ def forward(self, inputs):
190185
x = inputs
191186
inputs = quantize_active_overflow(inputs, self.bitA) # Do not remove
192187
outputs = tf.nn.conv2d(
193-
input=x,
194-
filters=self.W,
195-
strides=self._strides,
196-
padding=self.padding,
197-
data_format=self.data_format,
198-
dilations=self._dilation_rate,
199-
name=self.name
188+
input=x, filters=self.W, strides=self._strides, padding=self.padding, data_format=self.data_format,
189+
dilations=self._dilation_rate, name=self.name
200190
)
201191

202192
mean, variance = tf.nn.moments(outputs, axes=list(range(len(outputs.get_shape()) - 1)))
203193

204194
update_moving_mean = moving_averages.assign_moving_average(
205-
self.moving_mean, mean, self.decay, zero_debias=False) # if zero_debias=True, has bias
195+
self.moving_mean, mean, self.decay, zero_debias=False
196+
) # if zero_debias=True, has bias
206197
update_moving_variance = moving_averages.assign_moving_average(
207-
self.moving_variance, mean, self.decay, zero_debias=False) # if zero_debias=True, has bias
198+
self.moving_variance, mean, self.decay, zero_debias=False
199+
) # if zero_debias=True, has bias
208200

209201
if self.is_train:
210202
mean, var = self.mean_var_with_update(update_moving_mean, update_moving_variance, mean, variance)
@@ -215,9 +207,7 @@ def forward(self, inputs):
215207

216208
W_ = quantize_weight_overflow(w_fold, self.bitW)
217209

218-
conv_fold = tf.nn.conv2d(
219-
inputs, W_, strides=self.strides, padding=self.padding, data_format=self.data_format
220-
)
210+
conv_fold = tf.nn.conv2d(inputs, W_, strides=self.strides, padding=self.padding, data_format=self.data_format)
221211

222212
if self.beta_init:
223213
bias_fold = self._bias_fold(self.offset_para, self.scale_para, mean, var, self.epsilon)

tensorlayer/layers/dense/quan_dense_bn.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
from tensorlayer import logging
99
from tensorlayer.decorators import deprecated_alias
1010
from tensorlayer.layers.core import Layer
11-
from tensorlayer.layers.utils import (quantize_active_overflow,
12-
quantize_weight_overflow)
11+
from tensorlayer.layers.utils import (quantize_active_overflow, quantize_weight_overflow)
1312

1413
__all__ = [
1514
'QuanDenseLayerWithBN',
@@ -62,20 +61,20 @@ class QuanDenseLayerWithBN(Layer):
6261
"""
6362

6463
def __init__(
65-
self,
66-
n_units=100,
67-
act=None,
68-
decay=0.9,
69-
epsilon=1e-5,
70-
is_train=False,
71-
bitW=8,
72-
bitA=8,
73-
gamma_init=tl.initializers.truncated_normal(stddev=0.05),
74-
beta_init=tl.initializers.truncated_normal(stddev=0.05),
75-
use_gemm=False,
76-
W_init=tl.initializers.truncated_normal(stddev=0.05),
77-
W_init_args=None,
78-
name=None, # 'quan_dense_with_bn',
64+
self,
65+
n_units=100,
66+
act=None,
67+
decay=0.9,
68+
epsilon=1e-5,
69+
is_train=False,
70+
bitW=8,
71+
bitA=8,
72+
gamma_init=tl.initializers.truncated_normal(stddev=0.05),
73+
beta_init=tl.initializers.truncated_normal(stddev=0.05),
74+
use_gemm=False,
75+
W_init=tl.initializers.truncated_normal(stddev=0.05),
76+
W_init_args=None,
77+
name=None, # 'quan_dense_with_bn',
7978
):
8079
super(QuanDenseLayerWithBN, self).__init__(act=act, W_init_args=W_init_args, name=name)
8180
self.n_units = n_units
@@ -115,7 +114,7 @@ def build(self, inputs_shape):
115114
n_in = inputs_shape[-1]
116115
self.W = self._get_weights("weights", shape=(n_in, self.n_units), init=self.W_init)
117116

118-
para_bn_shape = (self.n_units,)
117+
para_bn_shape = (self.n_units, )
119118
if self.gamma_init:
120119
self.scale_para = self._get_weights("gamm_weights", shape=para_bn_shape, init=self.gamma_init)
121120
else:
@@ -127,15 +126,11 @@ def build(self, inputs_shape):
127126
self.offset_para = None
128127

129128
self.moving_mean = self._get_weights(
130-
"moving_mean",
131-
shape=para_bn_shape,
132-
init=tl.initializers.constant(1.0),
133-
trainable=False)
129+
"moving_mean", shape=para_bn_shape, init=tl.initializers.constant(1.0), trainable=False
130+
)
134131
self.moving_variance = self._get_weights(
135-
"moving_variacne",
136-
shape=para_bn_shape,
137-
init=tl.initializers.constant(1.0),
138-
trainable=False)
132+
"moving_variacne", shape=para_bn_shape, init=tl.initializers.constant(1.0), trainable=False
133+
)
139134

140135
def forward(self, inputs):
141136
x = inputs

0 commit comments

Comments
 (0)