Skip to content

Commit d5663b3

Browse files
authored
Use reduce_mean instead of Average Pooling in Resnet (#3675)
* Try out affine layer instead of dense
* Use reduce_mean instead of avg pooling
* Remove np
* Fix axes
* Cleanup
* Fix comment
* Fix tests
1 parent 5ef68f5 commit d5663b3

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

official/resnet/imagenet_test.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,7 @@ def reshape(shape):
6464
block_layer2 = graph.get_tensor_by_name('block_layer2:0')
6565
block_layer3 = graph.get_tensor_by_name('block_layer3:0')
6666
block_layer4 = graph.get_tensor_by_name('block_layer4:0')
67-
avg_pool = graph.get_tensor_by_name('final_avg_pool:0')
67+
reduce_mean = graph.get_tensor_by_name('final_reduce_mean:0')
6868
dense = graph.get_tensor_by_name('final_dense:0')
6969

7070
self.assertAllEqual(initial_conv.shape, reshape((1, 64, 112, 112)))
@@ -77,13 +77,13 @@ def reshape(shape):
7777
self.assertAllEqual(block_layer2.shape, reshape((1, 128, 28, 28)))
7878
self.assertAllEqual(block_layer3.shape, reshape((1, 256, 14, 14)))
7979
self.assertAllEqual(block_layer4.shape, reshape((1, 512, 7, 7)))
80-
self.assertAllEqual(avg_pool.shape, reshape((1, 512, 1, 1)))
80+
self.assertAllEqual(reduce_mean.shape, reshape((1, 512, 1, 1)))
8181
else:
8282
self.assertAllEqual(block_layer1.shape, reshape((1, 256, 56, 56)))
8383
self.assertAllEqual(block_layer2.shape, reshape((1, 512, 28, 28)))
8484
self.assertAllEqual(block_layer3.shape, reshape((1, 1024, 14, 14)))
8585
self.assertAllEqual(block_layer4.shape, reshape((1, 2048, 7, 7)))
86-
self.assertAllEqual(avg_pool.shape, reshape((1, 2048, 1, 1)))
86+
self.assertAllEqual(reduce_mean.shape, reshape((1, 2048, 1, 1)))
8787

8888
self.assertAllEqual(dense.shape, (1, _LABEL_CLASSES))
8989
self.assertAllEqual(output.shape, (1, _LABEL_CLASSES))

official/resnet/resnet_model.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,8 @@
3131
from __future__ import division
3232
from __future__ import print_function
3333

34-
3534
import tensorflow as tf
3635

37-
3836
_BATCH_NORM_DECAY = 0.997
3937
_BATCH_NORM_EPSILON = 1e-5
4038
DEFAULT_VERSION = 2
@@ -461,13 +459,18 @@ def __call__(self, inputs, training):
461459

462460
inputs = batch_norm(inputs, training, self.data_format)
463461
inputs = tf.nn.relu(inputs)
464-
inputs = tf.layers.average_pooling2d(
465-
inputs=inputs, pool_size=self.second_pool_size,
466-
strides=self.second_pool_stride, padding='VALID',
467-
data_format=self.data_format)
468-
inputs = tf.identity(inputs, 'final_avg_pool')
462+
463+
# The current top layer has shape
464+
# `batch_size x pool_size x pool_size x final_size`.
465+
# ResNet does an Average Pooling layer over pool_size,
466+
# but that is the same as doing a reduce_mean. We do a reduce_mean
467+
# here because it performs better than AveragePooling2D.
468+
axes = [2, 3] if self.data_format == 'channels_first' else [1, 2]
469+
inputs = tf.reduce_mean(inputs, axes, keepdims=True)
470+
inputs = tf.identity(inputs, 'final_reduce_mean')
469471

470472
inputs = tf.reshape(inputs, [-1, self.final_size])
471473
inputs = tf.layers.dense(inputs=inputs, units=self.num_classes)
472474
inputs = tf.identity(inputs, 'final_dense')
475+
473476
return inputs

0 commit comments

Comments (0)