
Commit 52c6556

Author: Allen Wang

Cherry-pick classifier-trainer to r2.2.0.

1 parent: 20f16bb

File tree

9 files changed: +691 −61 lines

official/benchmark/keras_imagenet_benchmark.py

Lines changed: 548 additions & 14 deletions
Large diffs are not rendered by default.

official/modeling/hyperparams/base_config.py

Lines changed: 9 additions & 2 deletions
@@ -257,7 +257,6 @@ class RuntimeConfig(Config):
 
   Attributes:
     distribution_strategy: e.g. 'mirrored', 'tpu', etc.
-    enable_eager: Whether or not to enable eager mode.
     enable_xla: Whether or not to enable XLA.
     per_gpu_thread_count: thread count per GPU.
     gpu_threads_enabled: Whether or not GPU threads are enabled.
@@ -272,9 +271,12 @@ class RuntimeConfig(Config):
     all_reduce_alg: Defines the algorithm for performing all-reduce.
     num_packs: Sets `num_packs` in the cross device ops used in
       MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
+    loss_scale: The type of loss scale. This is used when setting the mixed
+      precision policy.
+    run_eagerly: Whether or not to run the experiment eagerly.
+
   """
   distribution_strategy: str = 'mirrored'
-  enable_eager: bool = False
   enable_xla: bool = False
   gpu_threads_enabled: bool = False
   gpu_thread_mode: Optional[str] = None
@@ -286,6 +288,8 @@ class RuntimeConfig(Config):
   task_index: int = -1
   all_reduce_alg: Optional[str] = None
   num_packs: int = 1
+  loss_scale: Optional[str] = None
+  run_eagerly: bool = False
 
 
 @dataclasses.dataclass
@@ -312,7 +316,10 @@ class CallbacksConfig(Config):
       Callback. Defaults to True.
     enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
       Defaults to True.
+    enable_time_history: Whether or not to enable TimeHistory Callbacks.
+      Defaults to True.
 
   """
   enable_checkpoint_and_export: bool = True
   enable_tensorboard: bool = True
+  enable_time_history: bool = True
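
To make the new config surface concrete, here is a minimal sketch of constructing these dataclasses with the replacement fields. It is illustrative only: the `base_configs` alias mirrors the one used in classifier_trainer_test.py below, and the chosen values are placeholders.

# Illustrative sketch, not part of this commit. Uses the fields added above;
# `base_configs` is assumed to be the module alias from the test file below.
runtime = base_configs.RuntimeConfig(
    distribution_strategy='mirrored',
    run_eagerly=False,     # replaces the removed `enable_eager` field
    loss_scale='dynamic')  # consumed when setting the mixed precision policy
callbacks = base_configs.CallbacksConfig(
    enable_tensorboard=True,
    enable_time_history=True)  # new: toggles the TimeHistory callback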

official/vision/image_classification/callbacks.py

Lines changed: 18 additions & 7 deletions
@@ -1,3 +1,4 @@
+# Lint as: python3
 # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,15 +23,20 @@
 from absl import logging
 
 import tensorflow as tf
-from typing import Any, List, MutableMapping, Text
+from typing import Any, List, MutableMapping
+
+from official.utils.misc import keras_utils
 
 
 def get_callbacks(model_checkpoint: bool = True,
                   include_tensorboard: bool = True,
+                  time_history: bool = True,
                   track_lr: bool = True,
                   write_model_weights: bool = True,
                   initial_step: int = 0,
-                  model_dir: Text = None) -> List[tf.keras.callbacks.Callback]:
+                  batch_size: int = 0,
+                  log_steps: int = 0,
+                  model_dir: str = None) -> List[tf.keras.callbacks.Callback]:
   """Get all callbacks."""
   model_dir = model_dir or ''
   callbacks = []
@@ -44,6 +50,11 @@ def get_callbacks(model_checkpoint: bool = True,
         track_lr=track_lr,
         initial_step=initial_step,
         write_images=write_model_weights))
+  if time_history:
+    callbacks.append(keras_utils.TimeHistory(
+        batch_size,
+        log_steps,
+        logdir=model_dir if include_tensorboard else None))
   return callbacks
 
 
@@ -74,7 +85,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
   # classification loss
 
   def __init__(self,
-               log_dir: Text,
+               log_dir: str,
               track_lr: bool = False,
               initial_step: int = 0,
               **kwargs):
@@ -84,7 +95,7 @@ def __init__(self,
 
   def on_batch_begin(self,
                      epoch: int,
-                     logs: MutableMapping[Text, Any] = None) -> None:
+                     logs: MutableMapping[str, Any] = None) -> None:
     self.step += 1
     if logs is None:
       logs = {}
@@ -93,7 +104,7 @@ def on_batch_begin(self,
 
   def on_epoch_begin(self,
                      epoch: int,
-                     logs: MutableMapping[Text, Any] = None) -> None:
+                     logs: MutableMapping[str, Any] = None) -> None:
     if logs is None:
       logs = {}
     metrics = self._calculate_metrics()
@@ -104,14 +115,14 @@ def on_epoch_begin(self,
 
   def on_epoch_end(self,
                    epoch: int,
-                   logs: MutableMapping[Text, Any] = None) -> None:
+                   logs: MutableMapping[str, Any] = None) -> None:
     if logs is None:
       logs = {}
     metrics = self._calculate_metrics()
     logs.update(metrics)
     super(CustomTensorBoard, self).on_epoch_end(epoch, logs)
 
-  def _calculate_metrics(self) -> MutableMapping[Text, Any]:
+  def _calculate_metrics(self) -> MutableMapping[str, Any]:
     logs = {}
     if self._track_lr:
       logs['learning_rate'] = self._calculate_lr()
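
A sketch of how the extended signature might be called after this change; the argument names come from the diff, while the concrete values are placeholders.

# Illustrative call into the updated get_callbacks(); values are placeholders.
# With time_history=True, a keras_utils.TimeHistory callback is appended that
# logs batch-level stats every `log_steps` steps and writes to the TensorBoard
# logdir only when include_tensorboard is also set.
callbacks = get_callbacks(
    model_checkpoint=True,
    include_tensorboard=True,
    time_history=True,
    track_lr=True,
    write_model_weights=False,
    initial_step=0,
    batch_size=256,   # global batch size, used for throughput stats
    log_steps=100,    # matches the default of the new --log_steps flag
    model_dir='/tmp/model_dir')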

official/vision/image_classification/classifier_trainer.py

Lines changed: 59 additions & 29 deletions
@@ -44,10 +44,24 @@
 from official.vision.image_classification.resnet import common
 from official.vision.image_classification.resnet import resnet_model
 
-MODELS = {
-    'efficientnet': efficientnet_model.EfficientNet.from_name,
-    'resnet': resnet_model.resnet50,
-}
+
+def get_models() -> Mapping[str, tf.keras.Model]:
+  """Returns the mapping from model type name to Keras model."""
+  return {
+      'efficientnet': efficientnet_model.EfficientNet.from_name,
+      'resnet': resnet_model.resnet50,
+  }
+
+
+def get_dtype_map() -> Mapping[str, tf.dtypes.DType]:
+  """Returns the mapping from dtype string representations to TF dtypes."""
+  return {
+      'float32': tf.float32,
+      'bfloat16': tf.bfloat16,
+      'float16': tf.float16,
+      'fp32': tf.float32,
+      'bf16': tf.bfloat16,
+  }
 
 
 def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:
@@ -120,7 +134,7 @@ def _get_dataset_builders(params: base_configs.ExperimentConfig,
 def get_loss_scale(params: base_configs.ExperimentConfig,
                    fp16_default: float = 128.) -> float:
   """Returns the loss scale for initializations."""
-  loss_scale = params.model.loss.loss_scale
+  loss_scale = params.runtime.loss_scale
   if loss_scale == 'dynamic':
     return loss_scale
   elif loss_scale is not None:
@@ -145,7 +159,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
           'name': model,
       },
       'runtime': {
-          'enable_eager': flags_obj.enable_eager,
+          'run_eagerly': flags_obj.run_eagerly,
           'tpu': flags_obj.tpu,
       },
       'train_dataset': {
@@ -154,8 +168,10 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
       'validation_dataset': {
           'data_dir': flags_obj.data_dir,
       },
-      'test_dataset': {
-          'data_dir': flags_obj.data_dir,
+      'train': {
+          'time_history': {
+              'log_steps': flags_obj.log_steps,
+          },
       },
   }
 
@@ -212,10 +228,11 @@ def resume_from_checkpoint(model: tf.keras.Model,
   return int(initial_epoch)
 
 
-def initialize(params: base_configs.ExperimentConfig):
+def initialize(params: base_configs.ExperimentConfig,
+               dataset_builder: dataset_factory.DatasetBuilder):
   """Initializes backend related initializations."""
   keras_utils.set_session_config(
-      enable_eager=params.runtime.enable_eager,
+      enable_eager=params.runtime.run_eagerly,
       enable_xla=params.runtime.enable_xla)
   if params.runtime.gpu_threads_enabled:
     keras_utils.set_gpu_thread_mode_and_count(
@@ -224,20 +241,19 @@ def initialize(params: base_configs.ExperimentConfig,
         num_gpus=params.runtime.num_gpus,
         datasets_num_private_threads=params.runtime.dataset_num_private_threads)
 
-  dataset = params.train_dataset or params.validation_dataset
-  performance.set_mixed_precision_policy(dataset.dtype)
+  performance.set_mixed_precision_policy(dataset_builder.dtype)
 
-  if dataset.data_format:
-    data_format = dataset.data_format
-  elif tf.config.list_physical_devices('GPU'):
+  if dataset_builder.config.data_format:
+    data_format = dataset_builder.config.data_format
+  if tf.config.list_physical_devices('GPU'):
     data_format = 'channels_first'
   else:
     data_format = 'channels_last'
   tf.keras.backend.set_image_data_format(data_format)
   distribution_utils.configure_cluster(
       params.runtime.worker_hosts,
       params.runtime.task_index)
-  if params.runtime.enable_eager:
+  if params.runtime.run_eagerly:
     # Enable eager execution to allow step-by-step debugging
     tf.config.experimental_run_functions_eagerly(True)
 
@@ -254,7 +270,7 @@ def define_classifier_flags():
       default=None,
      help='Mode to run: `train`, `eval`, `train_and_eval` or `export`.')
   flags.DEFINE_bool(
-      'enable_eager',
+      'run_eagerly',
      default=None,
      help='Use eager execution and disable autograph for debugging.')
   flags.DEFINE_string(
@@ -265,6 +281,10 @@ def define_classifier_flags():
       'dataset',
      default=None,
      help='The name of the dataset, e.g. ImageNet, etc.')
+  flags.DEFINE_integer(
+      'log_steps',
+      default=100,
+      help='The interval of steps between logging of batch level stats.')
 
 
 def serialize_config(params: base_configs.ExperimentConfig,
@@ -307,11 +327,13 @@ def train_and_eval(
   train_steps = params.train.steps or train_builder.num_steps
   validation_steps = params.evaluation.steps or validation_builder.num_steps
 
+  initialize(params, train_builder)
+
   logging.info('Global batch size: %d', train_builder.global_batch_size)
 
   with strategy_scope:
     model_params = params.model.model_params.as_dict()
-    model = MODELS[params.model.name](**model_params)
+    model = get_models()[params.model.name](**model_params)
     learning_rate = optimizer_factory.build_learning_rate(
         params=params.model.learning_rate,
         batch_size=train_builder.global_batch_size,
@@ -331,8 +353,7 @@ def train_and_eval(
     loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
     model.compile(optimizer=optimizer,
                   loss=loss_obj,
-                  metrics=metrics,
-                  run_eagerly=params.runtime.enable_eager)
+                  metrics=metrics)
 
   initial_epoch = 0
   if params.train.resume_checkpoint:
@@ -345,26 +366,37 @@ def train_and_eval(
   callbacks = custom_callbacks.get_callbacks(
       model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
       include_tensorboard=params.train.callbacks.enable_tensorboard,
+      time_history=params.train.callbacks.enable_time_history,
       track_lr=params.train.tensorboard.track_lr,
      write_model_weights=params.train.tensorboard.write_model_weights,
      initial_step=initial_epoch * train_steps,
+      batch_size=train_builder.global_batch_size,
+      log_steps=params.train.time_history.log_steps,
      model_dir=params.model_dir)
 
+  if params.evaluation.skip_eval:
+    validation_kwargs = {}
+  else:
+    validation_kwargs = {
+        'validation_data': validation_dataset,
+        'validation_steps': validation_steps,
+        'validation_freq': params.evaluation.epochs_between_evals,
+    }
+
   history = model.fit(
       train_dataset,
      epochs=train_epochs,
      steps_per_epoch=train_steps,
      initial_epoch=initial_epoch,
      callbacks=callbacks,
-      validation_data=validation_dataset,
-      validation_steps=validation_steps,
-      validation_freq=params.evaluation.epochs_between_evals)
+      **validation_kwargs)
 
-  validation_output = model.evaluate(
-      validation_dataset, steps=validation_steps, verbose=2)
+  validation_output = None
+  if not params.evaluation.skip_eval:
+    validation_output = model.evaluate(
+        validation_dataset, steps=validation_steps, verbose=2)
 
   # TODO(dankondratyuk): eval and save final test accuracy
-
   stats = common.build_stats(history,
                              validation_output,
                              callbacks)
@@ -375,7 +407,7 @@ def export(params: base_configs.ExperimentConfig):
   """Runs the model export functionality."""
   logging.info('Exporting model.')
   model_params = params.model.model_params.as_dict()
-  model = MODELS[params.model.name](**model_params)
+  model = get_models()[params.model.name](**model_params)
   checkpoint = params.export.checkpoint
   if checkpoint is None:
     logging.info('No export checkpoint was provided. Using the latest '
@@ -398,8 +430,6 @@ def run(flags_obj: flags.FlagValues,
     Dictionary of training/eval stats
   """
   params = _get_params_from_flags(flags_obj)
-  initialize(params)
-
   if params.mode == 'train_and_eval':
     return train_and_eval(params, strategy_override)
   elif params.mode == 'export_only':
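
One behavioral note from the hunks above: get_loss_scale now reads params.runtime.loss_scale instead of params.model.loss.loss_scale, and the hunk shows only its first two branches. A self-contained sketch of the full resolution order, with the fallback branches assumed from the fp16_default parameter rather than copied from the file:

# Sketch of the loss-scale resolution logic. The 'dynamic' and explicit-value
# branches are visible in the diff above; the dtype-based fallback is an
# assumption for illustration and may differ from the actual file.
def resolve_loss_scale(runtime_loss_scale, dtype, fp16_default=128.):
  if runtime_loss_scale == 'dynamic':
    return runtime_loss_scale           # dynamic loss scaling
  elif runtime_loss_scale is not None:
    return float(runtime_loss_scale)    # explicit numeric scale
  elif dtype == 'float16':
    return fp16_default                 # fp16 generally needs a loss scale
  return 1.                             # float32/bfloat16: no scaling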

official/vision/image_classification/classifier_trainer_test.py

Lines changed: 11 additions & 4 deletions
@@ -239,8 +239,8 @@ def test_get_model_size(self, model, model_name, expected):
   )
   def test_get_loss_scale(self, loss_scale, dtype, expected):
     config = base_configs.ExperimentConfig(
-        model=base_configs.ModelConfig(
-            loss=base_configs.LossConfig(loss_scale=loss_scale)),
+        runtime=base_configs.RuntimeConfig(
+            loss_scale=loss_scale),
         train_dataset=dataset_factory.DatasetConfig(dtype=dtype))
     ls = classifier_trainer.get_loss_scale(config, fp16_default=128)
     self.assertEqual(ls, expected)
@@ -252,7 +252,7 @@ def test_get_loss_scale(self, loss_scale, dtype, expected):
   def test_initialize(self, dtype):
     config = base_configs.ExperimentConfig(
         runtime=base_configs.RuntimeConfig(
-            enable_eager=False,
+            run_eagerly=False,
             enable_xla=False,
             gpu_threads_enabled=True,
             per_gpu_thread_count=1,
@@ -264,7 +264,14 @@ def test_initialize(self, dtype):
         model=base_configs.ModelConfig(
             loss=base_configs.LossConfig(loss_scale='dynamic')),
     )
-    classifier_trainer.initialize(config)
+
+    class EmptyClass:
+      pass
+    fake_ds_builder = EmptyClass()
+    fake_ds_builder.dtype = dtype
+    fake_ds_builder.config = EmptyClass()
+    fake_ds_builder.config.data_format = None
+    classifier_trainer.initialize(config, fake_ds_builder)
 
   def test_resume_from_checkpoint(self):
     """Tests functionality for resuming from checkpoint."""
