44
44
from official .vision .image_classification .resnet import common
45
45
from official .vision .image_classification .resnet import resnet_model
46
46
47
- MODELS = {
48
- 'efficientnet' : efficientnet_model .EfficientNet .from_name ,
49
- 'resnet' : resnet_model .resnet50 ,
50
- }
47
+
48
def get_models() -> Mapping[str, tf.keras.Model]:
  """Returns the mapping from model type name to Keras model.

  Returns:
    A dict keyed by model type name ('efficientnet', 'resnet') whose values
    are the callables used to construct the corresponding Keras model.
    NOTE(review): the annotation says `tf.keras.Model`, but the values are
    model *constructors* — confirm against callers before tightening it.
  """
  models = {}
  models['efficientnet'] = efficientnet_model.EfficientNet.from_name
  models['resnet'] = resnet_model.resnet50
  return models
54
+
55
+
56
def get_dtype_map() -> Mapping[str, tf.dtypes.DType]:
  """Returns the mapping from dtype string representations to TF dtypes.

  Accepts both the full dtype names ('float32', 'bfloat16', 'float16') and
  the short aliases ('fp32', 'bf16'), which resolve to the same TF dtypes.
  """
  dtype_map = {}
  dtype_map['float32'] = tf.float32
  dtype_map['bfloat16'] = tf.bfloat16
  dtype_map['float16'] = tf.float16
  # Short aliases for convenience on the command line / in configs.
  dtype_map['fp32'] = tf.float32
  dtype_map['bf16'] = tf.bfloat16
  return dtype_map
51
65
52
66
53
67
def _get_metrics (one_hot : bool ) -> Mapping [Text , Any ]:
@@ -120,7 +134,7 @@ def _get_dataset_builders(params: base_configs.ExperimentConfig,
120
134
def get_loss_scale (params : base_configs .ExperimentConfig ,
121
135
fp16_default : float = 128. ) -> float :
122
136
"""Returns the loss scale for initializations."""
123
- loss_scale = params .model . loss .loss_scale
137
+ loss_scale = params .runtime .loss_scale
124
138
if loss_scale == 'dynamic' :
125
139
return loss_scale
126
140
elif loss_scale is not None :
@@ -145,7 +159,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
145
159
'name' : model ,
146
160
},
147
161
'runtime' : {
148
- 'enable_eager ' : flags_obj .enable_eager ,
162
+ 'run_eagerly ' : flags_obj .run_eagerly ,
149
163
'tpu' : flags_obj .tpu ,
150
164
},
151
165
'train_dataset' : {
@@ -154,8 +168,10 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
154
168
'validation_dataset' : {
155
169
'data_dir' : flags_obj .data_dir ,
156
170
},
157
- 'test_dataset' : {
158
- 'data_dir' : flags_obj .data_dir ,
171
+ 'train' : {
172
+ 'time_history' : {
173
+ 'log_steps' : flags_obj .log_steps ,
174
+ },
159
175
},
160
176
}
161
177
@@ -212,10 +228,11 @@ def resume_from_checkpoint(model: tf.keras.Model,
212
228
return int (initial_epoch )
213
229
214
230
215
- def initialize (params : base_configs .ExperimentConfig ):
231
+ def initialize (params : base_configs .ExperimentConfig ,
232
+ dataset_builder : dataset_factory .DatasetBuilder ):
216
233
"""Initializes backend related initializations."""
217
234
keras_utils .set_session_config (
218
- enable_eager = params .runtime .enable_eager ,
235
+ enable_eager = params .runtime .run_eagerly ,
219
236
enable_xla = params .runtime .enable_xla )
220
237
if params .runtime .gpu_threads_enabled :
221
238
keras_utils .set_gpu_thread_mode_and_count (
@@ -224,20 +241,19 @@ def initialize(params: base_configs.ExperimentConfig):
224
241
num_gpus = params .runtime .num_gpus ,
225
242
datasets_num_private_threads = params .runtime .dataset_num_private_threads )
226
243
227
- dataset = params .train_dataset or params .validation_dataset
228
- performance .set_mixed_precision_policy (dataset .dtype )
244
+ performance .set_mixed_precision_policy (dataset_builder .dtype )
229
245
230
- if dataset .data_format :
231
- data_format = dataset .data_format
232
- elif tf .config .list_physical_devices ('GPU' ):
246
+ if dataset_builder . config .data_format :
247
+ data_format = dataset_builder . config .data_format
248
+ if tf .config .list_physical_devices ('GPU' ):
233
249
data_format = 'channels_first'
234
250
else :
235
251
data_format = 'channels_last'
236
252
tf .keras .backend .set_image_data_format (data_format )
237
253
distribution_utils .configure_cluster (
238
254
params .runtime .worker_hosts ,
239
255
params .runtime .task_index )
240
- if params .runtime .enable_eager :
256
+ if params .runtime .run_eagerly :
241
257
# Enable eager execution to allow step-by-step debugging
242
258
tf .config .experimental_run_functions_eagerly (True )
243
259
@@ -254,7 +270,7 @@ def define_classifier_flags():
254
270
default = None ,
255
271
help = 'Mode to run: `train`, `eval`, `train_and_eval` or `export`.' )
256
272
flags .DEFINE_bool (
257
- 'enable_eager ' ,
273
+ 'run_eagerly ' ,
258
274
default = None ,
259
275
help = 'Use eager execution and disable autograph for debugging.' )
260
276
flags .DEFINE_string (
@@ -265,6 +281,10 @@ def define_classifier_flags():
265
281
'dataset' ,
266
282
default = None ,
267
283
help = 'The name of the dataset, e.g. ImageNet, etc.' )
284
+ flags .DEFINE_integer (
285
+ 'log_steps' ,
286
+ default = 100 ,
287
+ help = 'The interval of steps between logging of batch level stats.' )
268
288
269
289
270
290
def serialize_config (params : base_configs .ExperimentConfig ,
@@ -307,11 +327,13 @@ def train_and_eval(
307
327
train_steps = params .train .steps or train_builder .num_steps
308
328
validation_steps = params .evaluation .steps or validation_builder .num_steps
309
329
330
+ initialize (params , train_builder )
331
+
310
332
logging .info ('Global batch size: %d' , train_builder .global_batch_size )
311
333
312
334
with strategy_scope :
313
335
model_params = params .model .model_params .as_dict ()
314
- model = MODELS [params .model .name ](** model_params )
336
+ model = get_models () [params .model .name ](** model_params )
315
337
learning_rate = optimizer_factory .build_learning_rate (
316
338
params = params .model .learning_rate ,
317
339
batch_size = train_builder .global_batch_size ,
@@ -331,8 +353,7 @@ def train_and_eval(
331
353
loss_obj = tf .keras .losses .SparseCategoricalCrossentropy ()
332
354
model .compile (optimizer = optimizer ,
333
355
loss = loss_obj ,
334
- metrics = metrics ,
335
- run_eagerly = params .runtime .enable_eager )
356
+ metrics = metrics )
336
357
337
358
initial_epoch = 0
338
359
if params .train .resume_checkpoint :
@@ -345,26 +366,37 @@ def train_and_eval(
345
366
callbacks = custom_callbacks .get_callbacks (
346
367
model_checkpoint = params .train .callbacks .enable_checkpoint_and_export ,
347
368
include_tensorboard = params .train .callbacks .enable_tensorboard ,
369
+ time_history = params .train .callbacks .enable_time_history ,
348
370
track_lr = params .train .tensorboard .track_lr ,
349
371
write_model_weights = params .train .tensorboard .write_model_weights ,
350
372
initial_step = initial_epoch * train_steps ,
373
+ batch_size = train_builder .global_batch_size ,
374
+ log_steps = params .train .time_history .log_steps ,
351
375
model_dir = params .model_dir )
352
376
377
+ if params .evaluation .skip_eval :
378
+ validation_kwargs = {}
379
+ else :
380
+ validation_kwargs = {
381
+ 'validation_data' : validation_dataset ,
382
+ 'validation_steps' : validation_steps ,
383
+ 'validation_freq' : params .evaluation .epochs_between_evals ,
384
+ }
385
+
353
386
history = model .fit (
354
387
train_dataset ,
355
388
epochs = train_epochs ,
356
389
steps_per_epoch = train_steps ,
357
390
initial_epoch = initial_epoch ,
358
391
callbacks = callbacks ,
359
- validation_data = validation_dataset ,
360
- validation_steps = validation_steps ,
361
- validation_freq = params .evaluation .epochs_between_evals )
392
+ ** validation_kwargs )
362
393
363
- validation_output = model .evaluate (
364
- validation_dataset , steps = validation_steps , verbose = 2 )
394
+ validation_output = None
395
+ if not params .evaluation .skip_eval :
396
+ validation_output = model .evaluate (
397
+ validation_dataset , steps = validation_steps , verbose = 2 )
365
398
366
399
# TODO(dankondratyuk): eval and save final test accuracy
367
-
368
400
stats = common .build_stats (history ,
369
401
validation_output ,
370
402
callbacks )
@@ -375,7 +407,7 @@ def export(params: base_configs.ExperimentConfig):
375
407
"""Runs the model export functionality."""
376
408
logging .info ('Exporting model.' )
377
409
model_params = params .model .model_params .as_dict ()
378
- model = MODELS [params .model .name ](** model_params )
410
+ model = get_models () [params .model .name ](** model_params )
379
411
checkpoint = params .export .checkpoint
380
412
if checkpoint is None :
381
413
logging .info ('No export checkpoint was provided. Using the latest '
@@ -398,8 +430,6 @@ def run(flags_obj: flags.FlagValues,
398
430
Dictionary of training/eval stats
399
431
"""
400
432
params = _get_params_from_flags (flags_obj )
401
- initialize (params )
402
-
403
433
if params .mode == 'train_and_eval' :
404
434
return train_and_eval (params , strategy_override )
405
435
elif params .mode == 'export_only' :
0 commit comments