Skip to content

Commit 72238e5

Browse files
saberkunallenwang28
authored andcommitted
Fix channel_first layout for efficientnet.
PiperOrigin-RevId: 304281524
1 parent b749823 commit 72238e5

File tree

4 files changed

+6
-11
lines changed

4 files changed

+6
-11
lines changed

official/vision/image_classification/classifier_trainer.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,6 @@ def initialize(params: base_configs.ExperimentConfig,
242242
datasets_num_private_threads=params.runtime.dataset_num_private_threads)
243243

244244
performance.set_mixed_precision_policy(dataset_builder.dtype)
245-
246-
if dataset_builder.config.data_format:
247-
data_format = dataset_builder.config.data_format
248245
if tf.config.list_physical_devices('GPU'):
249246
data_format = 'channels_first'
250247
else:

official/vision/image_classification/classifier_trainer_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,6 @@ class EmptyClass:
264264
fake_ds_builder = EmptyClass()
265265
fake_ds_builder.dtype = dtype
266266
fake_ds_builder.config = EmptyClass()
267-
fake_ds_builder.config.data_format = None
268267
classifier_trainer.initialize(config, fake_ds_builder)
269268

270269
def test_resume_from_checkpoint(self):

official/vision/image_classification/dataset_factory.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,6 @@ class DatasetConfig(base_config.Config):
8787
(e.g., the number of GPUs or TPU cores).
8888
num_devices: The number of replica devices to use. This should be set by
8989
`strategy.num_replicas_in_sync` when using a distribution strategy.
90-
data_format: The data format of the images. Should be 'channels_last' or
91-
'channels_first'.
9290
dtype: The desired dtype of the dataset. This will be set during
9391
preprocessing.
9492
one_hot: Whether to apply one hot encoding. Set to `True` to be able to use
@@ -120,7 +118,6 @@ class DatasetConfig(base_config.Config):
120118
batch_size: int = 128
121119
use_per_replica_batch_size: bool = False
122120
num_devices: int = 1
123-
data_format: str = 'channels_last'
124121
dtype: str = 'float32'
125122
one_hot: bool = True
126123
augmenter: AugmentConfig = AugmentConfig()

official/vision/image_classification/efficientnet/efficientnet_model.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def conv2d_block(inputs: tf.Tensor,
166166
batch_norm = common_modules.get_batch_norm(config.batch_norm)
167167
bn_momentum = config.bn_momentum
168168
bn_epsilon = config.bn_epsilon
169-
data_format = config.data_format
169+
data_format = tf.keras.backend.image_data_format()
170170
weight_decay = config.weight_decay
171171

172172
name = name or ''
@@ -223,7 +223,7 @@ def mb_conv_block(inputs: tf.Tensor,
223223
use_se = config.use_se
224224
activation = tf_utils.get_activation(config.activation)
225225
drop_connect_rate = config.drop_connect_rate
226-
data_format = config.data_format
226+
data_format = tf.keras.backend.image_data_format()
227227
use_depthwise = block.conv_type != 'no_depthwise'
228228
prefix = prefix or ''
229229

@@ -346,12 +346,14 @@ def efficientnet(image_input: tf.keras.layers.Input,
346346
num_classes = config.num_classes
347347
input_channels = config.input_channels
348348
rescale_input = config.rescale_input
349-
data_format = config.data_format
349+
data_format = tf.keras.backend.image_data_format()
350350
dtype = config.dtype
351351
weight_decay = config.weight_decay
352352

353353
x = image_input
354-
354+
if data_format == 'channels_first':
355+
# Happens on GPU/TPU if available.
356+
x = tf.keras.layers.Permute((3, 1, 2))(x)
355357
if rescale_input:
356358
x = preprocessing.normalize_images(x,
357359
num_channels=input_channels,

0 commit comments

Comments
 (0)