Skip to content

Commit 24dcda6

Browse files
authored
Merge branch 'master' into Release
2 parents a3a4b28 + 6b4f39a commit 24dcda6

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

tensorlayer/layers/normalization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _bias_scale(x, b, data_format):
9292
if data_format == 'NHWC':
9393
return x * b
9494
elif data_format == 'NCHW':
95-
return x * _to_channel_first_bias(b)
95+
return x * b
9696
else:
9797
raise ValueError('invalid data_format: %s' % data_format)
9898

@@ -102,7 +102,7 @@ def _bias_add(x, b, data_format):
102102
if data_format == 'NHWC':
103103
return tf.add(x, b)
104104
elif data_format == 'NCHW':
105-
return tf.add(x, _to_channel_first_bias(b))
105+
return tf.add(x, b)
106106
else:
107107
raise ValueError('invalid data_format: %s' % data_format)
108108

@@ -291,9 +291,9 @@ def forward(self, inputs):
291291
if self.axes is None:
292292
self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]
293293

294+
mean, var = tf.nn.moments(inputs, self.axes, keepdims=False)
294295
if self.is_train:
295296
# update moving_mean and moving_var
296-
mean, var = tf.nn.moments(inputs, self.axes, keepdims=False)
297297
self.moving_mean = moving_averages.assign_moving_average(
298298
self.moving_mean, mean, self.decay, zero_debias=False
299299
)

tensorlayer/layers/pooling.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -324,17 +324,19 @@ def __repr__(self):
324324
return s.format(classname=self.__class__.__name__, **self.__dict__)
325325

326326
def build(self, inputs_shape=None):
327-
self._strides = [1, self.strides[0], self.strides[1], 1]
328327
if self.data_format == 'channels_last':
328+
self._strides = [1, self.strides[0], self.strides[1], 1]
329329
self.data_format = 'NHWC'
330330
elif self.data_format == 'channels_first':
331331
self.data_format = 'NCHW'
332+
self._strides = [1, 1, self.strides[0], self.strides[1]]
332333
else:
333334
raise Exception("unsupported data format")
334335

335336
def forward(self, inputs):
336337
outputs = tf.nn.max_pool(
337-
input=inputs, ksize=self.filter_size, strides=self._strides, padding=self.padding, name=self.name
338+
input=inputs, ksize=self.filter_size, strides=self._strides, padding=self.padding, name=self.name,
339+
data_format=self.data_format
338340
)
339341
return outputs
340342

@@ -397,17 +399,19 @@ def __repr__(self):
397399
return s.format(classname=self.__class__.__name__, **self.__dict__)
398400

399401
def build(self, inputs_shape=None):
400-
self._strides = [1, self.strides[0], self.strides[1], 1]
401402
if self.data_format == 'channels_last':
402403
self.data_format = 'NHWC'
404+
self._strides = [1, self.strides[0], self.strides[1], 1]
403405
elif self.data_format == 'channels_first':
404406
self.data_format = 'NCHW'
407+
self._strides = [1, 1, self.strides[0], self.strides[1]]
405408
else:
406409
raise Exception("unsupported data format")
407410

408411
def forward(self, inputs):
409412
outputs = tf.nn.avg_pool(
410-
input=inputs, ksize=self.filter_size, strides=self._strides, padding=self.padding, name=self.name
413+
input=inputs, ksize=self.filter_size, strides=self._strides, padding=self.padding, name=self.name,
414+
data_format=self.data_format
411415
)
412416
return outputs
413417

@@ -473,11 +477,12 @@ def __repr__(self):
473477
return s.format(classname=self.__class__.__name__, **self.__dict__)
474478

475479
def build(self, inputs_shape=None):
476-
self._strides = [1, self.strides[0], self.strides[1], self.strides[2], 1]
477480
if self.data_format == 'channels_last':
478481
self.data_format = 'NDHWC'
482+
self._strides = [1, self.strides[0], self.strides[1], self.strides[2], 1]
479483
elif self.data_format == 'channels_first':
480484
self.data_format = 'NCDHW'
485+
self._strides = [1, 1, self.strides[0], self.strides[1], self.strides[2]]
481486
else:
482487
raise Exception("unsupported data format")
483488

@@ -554,11 +559,12 @@ def __repr__(self):
554559
return s.format(classname=self.__class__.__name__, **self.__dict__)
555560

556561
def build(self, inputs_shape=None):
557-
self._strides = [1, self.strides[0], self.strides[1], self.strides[2], 1]
558562
if self.data_format == 'channels_last':
559563
self.data_format = 'NDHWC'
564+
self._strides = [1, self.strides[0], self.strides[1], self.strides[2], 1]
560565
elif self.data_format == 'channels_first':
561566
self.data_format = 'NCDHW'
567+
self._strides = [1, 1, self.strides[0], self.strides[1], self.strides[2]]
562568
else:
563569
raise Exception("unsupported data format")
564570

tensorlayer/models/vgg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def forward(self, inputs):
105105

106106
inputs = inputs * 255 - np.array([123.68, 116.779, 103.939], dtype=np.float32).reshape([1, 1, 1, 3])
107107

108-
out = self.layers(inputs)
108+
out = self.layers.forward(inputs)
109109
return out
110110

111111

0 commit comments

Comments
 (0)