Skip to content

Commit b86ffb1

Browse files
saberkuntensorflower-gardener
authored andcommitted
Internal change
PiperOrigin-RevId: 301639338
1 parent febaae9 commit b86ffb1

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

official/vision/image_classification/resnet/resnet_model.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,7 @@ def resnet50(num_classes,
255255
x = img_input
256256

257257
if backend.image_data_format() == 'channels_first':
258-
x = layers.Lambda(
259-
lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
260-
name='transpose')(x)
258+
x = layers.Permute((3, 1, 2))(x)
261259
bn_axis = 1
262260
else: # channels_last
263261
bn_axis = 3
@@ -382,8 +380,7 @@ def resnet50(num_classes,
382380
block='c',
383381
use_l2_regularizer=use_l2_regularizer)
384382

385-
rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
386-
x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
383+
x = layers.GlobalAveragePooling2D()(x)
387384
x = layers.Dense(
388385
num_classes,
389386
kernel_initializer=initializers.RandomNormal(stddev=0.01),

0 commit comments

Comments
 (0)