@@ -71,6 +71,20 @@ def build_stats(train_result, eval_result, time_callback):
 def get_input_dataset(flags_obj, strategy):
   """Returns the test and train input datasets."""
   dtype = flags_core.get_tf_dtype(flags_obj)
+  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+  batch_size = flags_obj.batch_size
+  if use_dataset_fn:
+    if batch_size % strategy.num_replicas_in_sync != 0:
+      raise ValueError(
+          'Batch size must be divisible by number of replicas : {}'.format(
+              strategy.num_replicas_in_sync))
+
+    # As auto rebatching is not supported in
+    # `experimental_distribute_datasets_from_function()` API, which is
+    # required when cloning dataset to multiple workers in eager mode,
+    # we use per-replica batch size.
+    batch_size = int(batch_size / strategy.num_replicas_in_sync)
+
   if flags_obj.use_synthetic_data:
     input_fn = common.get_synth_input_fn(
         height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
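The hunk above divides the global batch size by `strategy.num_replicas_in_sync` because, under `TPUStrategy`, each replica's input function has to produce per-replica batches (the distribute-datasets-from-function path does no automatic rebatching). As a rough illustration of the same arithmetic, and not part of this change, `tf.distribute.InputContext` can derive the per-replica batch size and performs the same divisibility check; the toy dataset and function name below are assumptions for the sketch.

# Illustration only (not part of the diff): InputContext computes the same
# per-replica batch size that get_input_dataset() derives by hand above.
import tensorflow as tf

def toy_dataset_fn(ctx: tf.distribute.InputContext, global_batch_size: int):
  # Divides global_batch_size by num_replicas_in_sync and raises ValueError
  # if it is not evenly divisible, mirroring the explicit check above.
  per_replica_batch = ctx.get_per_replica_batch_size(global_batch_size)
  return tf.data.Dataset.range(1024).batch(per_replica_batch,
                                           drop_remainder=True)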
@@ -82,34 +96,51 @@ def get_input_dataset(flags_obj, strategy):
   else:
     input_fn = imagenet_preprocessing.input_fn
 
-  train_ds = input_fn(
-      is_training=True,
-      data_dir=flags_obj.data_dir,
-      batch_size=flags_obj.batch_size,
-      parse_record_fn=imagenet_preprocessing.parse_record,
-      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
-      dtype=dtype)
+  def _train_dataset_fn(ctx=None):
+    train_ds = input_fn(
+        is_training=True,
+        data_dir=flags_obj.data_dir,
+        batch_size=batch_size,
+        parse_record_fn=imagenet_preprocessing.parse_record,
+        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
+        dtype=dtype,
+        input_context=ctx,
+        drop_remainder=True)
+    return train_ds
 
   if strategy:
-    train_ds = strategy.experimental_distribute_dataset(train_ds)
+    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
+      train_ds = strategy.experimental_distribute_datasets_from_function(_train_dataset_fn)
+    else:
+      train_ds = strategy.experimental_distribute_dataset(_train_dataset_fn())
+  else:
+    train_ds = _train_dataset_fn()
 
   test_ds = None
   if not flags_obj.skip_eval:
-    test_ds = input_fn(
-        is_training=False,
-        data_dir=flags_obj.data_dir,
-        batch_size=flags_obj.batch_size,
-        parse_record_fn=imagenet_preprocessing.parse_record,
-        dtype=dtype)
+    def _test_data_fn(ctx=None):
+      test_ds = input_fn(
+          is_training=False,
+          data_dir=flags_obj.data_dir,
+          batch_size=batch_size,
+          parse_record_fn=imagenet_preprocessing.parse_record,
+          dtype=dtype,
+          input_context=ctx)
+      return test_ds
 
-  if strategy:
-    test_ds = strategy.experimental_distribute_dataset(test_ds)
+    if strategy:
+      if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
+        test_ds = strategy.experimental_distribute_datasets_from_function(_test_data_fn)
+      else:
+        test_ds = strategy.experimental_distribute_dataset(_test_data_fn())
+    else:
+      test_ds = _test_data_fn()
 
   return train_ds, test_ds
 
 
 def get_num_train_iterations(flags_obj):
-  """Returns the number of training stesps, train and test epochs."""
+  """Returns the number of training steps, train and test epochs."""
   train_steps = (
       imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
   train_epochs = flags_obj.train_epochs
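The rewrite above turns dataset construction into the `_train_dataset_fn`/`_test_data_fn` closures: under `TPUStrategy`, `experimental_distribute_datasets_from_function` invokes the closure once per input pipeline with a `tf.distribute.InputContext`, while other strategies (or no strategy) still consume a single dataset. A minimal standalone sketch of that branching follows; the toy data and the `distribute_toy_dataset` name are illustrative, not code from this module.

# Minimal sketch of the distribution branching used above (toy data only).
import tensorflow as tf

def distribute_toy_dataset(strategy, per_replica_batch_size):
  def _dataset_fn(ctx=None):
    ds = tf.data.Dataset.range(2048)
    if ctx is not None:
      # When invoked through experimental_distribute_datasets_from_function,
      # the InputContext identifies this input pipeline so it can shard.
      ds = ds.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
    return ds.batch(per_replica_batch_size, drop_remainder=True)

  if strategy:
    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
      return strategy.experimental_distribute_datasets_from_function(_dataset_fn)
    return strategy.experimental_distribute_dataset(_dataset_fn())
  return _dataset_fn()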
@@ -124,6 +155,15 @@ def get_num_train_iterations(flags_obj):
   return train_steps, train_epochs, eval_steps
 
 
+def _steps_to_run(steps_in_current_epoch, steps_per_epoch, steps_per_loop):
+  """Calculates steps to run on device."""
+  if steps_per_loop <= 0:
+    raise ValueError('steps_per_loop should be positive integer.')
+  if steps_per_loop == 1:
+    return steps_per_loop
+  return min(steps_per_loop, steps_per_epoch - steps_in_current_epoch)
+
+
 def run(flags_obj):
   """Run ResNet ImageNet training and eval loop using custom training loops.
 
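`_steps_to_run` caps each inner training loop at `steps_per_loop` without running past the epoch boundary. A quick worked example with made-up numbers:

# Made-up numbers: an epoch of 10 steps with steps_per_loop=4 is executed as
# inner loops of 4, 4 and 2 steps.
steps_per_epoch, steps_per_loop = 10, 4
done, chunks = 0, []
while done < steps_per_epoch:
  steps = min(steps_per_loop, steps_per_epoch - done)  # what _steps_to_run returns
  chunks.append(steps)
  done += steps
assert chunks == [4, 4, 2]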
@@ -152,33 +192,45 @@ def run(flags_obj):
       num_gpus=flags_obj.num_gpus,
       num_workers=distribution_utils.configure_cluster(),
       all_reduce_alg=flags_obj.all_reduce_alg,
-      num_packs=flags_obj.num_packs)
+      num_packs=flags_obj.num_packs,
+      tpu_address=flags_obj.tpu)
 
   train_ds, test_ds = get_input_dataset(flags_obj, strategy)
-  train_steps, train_epochs, eval_steps = get_num_train_iterations(flags_obj)
+  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
+      flags_obj)
+  steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
+  logging.info("Training %d epochs, each epoch has %d steps, "
+               "total steps: %d; Eval %d steps",
+               train_epochs, per_epoch_steps, train_epochs * per_epoch_steps,
+               eval_steps)
 
   time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
                                           flags_obj.log_steps)
 
-  strategy_scope = distribution_utils.get_strategy_scope(strategy)
-  with strategy_scope:
+  with distribution_utils.get_strategy_scope(strategy):
     model = resnet_model.resnet50(
         num_classes=imagenet_preprocessing.NUM_CLASSES,
         batch_size=flags_obj.batch_size,
         use_l2_regularizer=not flags_obj.single_l2_loss_op)
 
-    optimizer = tf.keras.optimizers.SGD(
-        learning_rate=common.BASE_LEARNING_RATE, momentum=0.9,
-        nesterov=True)
-
-    if flags_obj.fp16_implementation == "graph_rewrite":
+    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
+        batch_size=flags_obj.batch_size,
+        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
+        warmup_epochs=common.LR_SCHEDULE[0][1],
+        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
+        multipliers=list(p[0] for p in common.LR_SCHEDULE),
+        compute_lr_on_cpu=True)
+    optimizer = common.get_optimizer(lr_schedule)
+
+    if flags_obj.fp16_implementation == 'graph_rewrite':
       if not flags_obj.use_tf_function:
-        raise ValueError("--fp16_implementation=graph_rewrite requires "
-                         "--use_tf_function to be true")
+        raise ValueError('--fp16_implementation=graph_rewrite requires '
+                         '--use_tf_function to be true')
       loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
       optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
           optimizer, loss_scale)
 
+    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
     training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
         'training_accuracy', dtype=tf.float32)
     test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
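The fixed-rate SGD optimizer is replaced above by `common.PiecewiseConstantDecayWithWarmup`, i.e. a warmup phase followed by piecewise-constant drops at the points listed in `common.LR_SCHEDULE`. The sketch below only approximates what such a schedule computes; it is not the implementation in `common.py`, and the class name and all constants are illustrative.

# Hedged sketch of a warmup + piecewise-constant schedule (illustrative only).
import tensorflow as tf

class WarmupPiecewiseConstant(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Linear warmup, then piecewise-constant multiples of a base rate."""

  def __init__(self, base_lr, warmup_steps, boundaries, multipliers):
    super(WarmupPiecewiseConstant, self).__init__()
    self.warmup_steps = float(warmup_steps)
    self.boundaries = [float(b) for b in boundaries]   # step boundaries
    self.values = [base_lr * m for m in multipliers]   # len(boundaries) + 1 values

  def __call__(self, step):
    step = tf.cast(step, tf.float32)
    # Ramp linearly towards the first post-warmup value.
    warmup_lr = self.values[0] * step / self.warmup_steps
    # Piecewise-constant part: index = number of boundaries already passed.
    index = tf.reduce_sum(tf.cast(step >= tf.constant(self.boundaries), tf.int32))
    decayed_lr = tf.gather(tf.constant(self.values), index)
    return tf.where(step < self.warmup_steps, warmup_lr, decayed_lr)

  def get_config(self):
    return {'warmup_steps': self.warmup_steps,
            'boundaries': self.boundaries,
            'values': self.values}

# Example wiring (hypothetical numbers):
# schedule = WarmupPiecewiseConstant(0.1, warmup_steps=500,
#                                    boundaries=[3000, 6000],
#                                    multipliers=[1.0, 0.1, 0.01])
# optimizer = tf.keras.optimizers.SGD(learning_rate=schedule, momentum=0.9)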
@@ -187,55 +239,56 @@ def run(flags_obj):
 
     trainable_variables = model.trainable_variables
 
-    def train_step(train_ds_inputs):
-      """Training StepFn."""
-      def step_fn(inputs):
-        """Per-Replica StepFn."""
-        images, labels = inputs
-        with tf.GradientTape() as tape:
-          logits = model(images, training=True)
-
-          prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
-              labels, logits)
-          loss = tf.reduce_sum(prediction_loss) * (1.0 / flags_obj.batch_size)
-          num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
-
-          if flags_obj.single_l2_loss_op:
-            filtered_variables = [
-                tf.reshape(v, (-1,))
-                for v in trainable_variables
-                if 'bn' not in v.name
-            ]
-            l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
-                tf.concat(filtered_variables, axis=0))
-            loss += (l2_loss / num_replicas)
-          else:
-            loss += (tf.reduce_sum(model.losses) / num_replicas)
-
-          # Scale the loss
-          if flags_obj.dtype == "fp16":
-            loss = optimizer.get_scaled_loss(loss)
-
-        grads = tape.gradient(loss, trainable_variables)
-
-        # Unscale the grads
+    def step_fn(inputs):
+      """Per-Replica StepFn."""
+      images, labels = inputs
+      with tf.GradientTape() as tape:
+        logits = model(images, training=True)
+
+        prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
+            labels, logits)
+        loss = tf.reduce_sum(prediction_loss) * (1.0 / flags_obj.batch_size)
+        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
+
+        if flags_obj.single_l2_loss_op:
+          filtered_variables = [
+              tf.reshape(v, (-1,))
+              for v in trainable_variables
+              if 'bn' not in v.name
+          ]
+          l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
+              tf.concat(filtered_variables, axis=0))
+          loss += (l2_loss / num_replicas)
+        else:
+          loss += (tf.reduce_sum(model.losses) / num_replicas)
+
+        # Scale the loss
         if flags_obj.dtype == "fp16":
-          grads = optimizer.get_unscaled_gradients(grads)
+          loss = optimizer.get_scaled_loss(loss)
 
-        optimizer.apply_gradients(zip(grads, trainable_variables))
+      grads = tape.gradient(loss, trainable_variables)
 
-        training_accuracy.update_state(labels, logits)
-        return loss
+      # Unscale the grads
+      if flags_obj.dtype == "fp16":
+        grads = optimizer.get_unscaled_gradients(grads)
 
+      optimizer.apply_gradients(zip(grads, trainable_variables))
+      train_loss.update_state(loss)
+      training_accuracy.update_state(labels, logits)
+
+    @tf.function
+    def train_steps(iterator, steps):
+      """Performs distributed training steps in a loop."""
+      for _ in tf.range(steps):
+        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
+
+    def train_single_step(iterator):
       if strategy:
-        per_replica_losses = strategy.experimental_run_v2(
-            step_fn, args=(train_ds_inputs,))
-        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
-                               axis=None)
+        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
       else:
-        return step_fn(train_ds_inputs)
+        return step_fn(next(iterator))
 
-    def test_step(test_ds_inputs):
+    def test_step(iterator):
       """Evaluation StepFn."""
       def step_fn(inputs):
         images, labels = inputs
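In `step_fn` above, each replica sums its per-example cross-entropy and divides by the global batch size, so summing (or all-reducing) the per-replica results reproduces the global-batch mean; the regularization term, which every replica computes, is divided by `num_replicas` so it is counted only once. A worked example with made-up numbers:

# Made-up numbers illustrating the normalization in step_fn above.
global_batch_size, num_replicas = 8, 2
per_replica_losses = [[0.5, 1.0, 1.5, 2.0],   # replica 0 (4 examples)
                      [1.0, 1.0, 2.0, 3.0]]   # replica 1 (4 examples)
global_mean = sum(sum(b) for b in per_replica_losses) / global_batch_size
# Each replica contributes sum(its losses) / GLOBAL batch size ...
replica_terms = [sum(b) / global_batch_size for b in per_replica_losses]
# ... so the summed contributions equal the mean over the whole global batch.
assert abs(sum(replica_terms) - global_mean) < 1e-9
# An additive penalty (e.g. the L2 term) shows up on every replica, so dividing
# it by num_replicas keeps it counted exactly once in the summed total.
l2_penalty = 0.1
total = sum(t + l2_penalty / num_replicas for t in replica_terms)
assert abs(total - (global_mean + l2_penalty)) < 1e-9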
@@ -247,34 +300,39 @@ def step_fn(inputs):
         test_accuracy.update_state(labels, logits)
 
       if strategy:
-        strategy.experimental_run_v2(step_fn, args=(test_ds_inputs,))
+        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
       else:
-        step_fn(test_ds_inputs)
+        step_fn(next(iterator))
 
     if flags_obj.use_tf_function:
-      train_step = tf.function(train_step)
+      train_single_step = tf.function(train_single_step)
       test_step = tf.function(test_step)
 
+    train_iter = iter(train_ds)
     time_callback.on_train_begin()
     for epoch in range(train_epochs):
-
-      train_iter = iter(train_ds)
-      total_loss = 0.0
+      train_loss.reset_states()
       training_accuracy.reset_states()
 
-      for step in range(train_steps):
-        optimizer.lr = common.learning_rate_schedule(
-            epoch, step, train_steps, flags_obj.batch_size)
-
-        time_callback.on_batch_begin(step + epoch * train_steps)
-        total_loss += train_step(next(train_iter))
-        time_callback.on_batch_end(step + epoch * train_steps)
-
-      train_loss = total_loss / train_steps
-      logging.info('Training loss: %s, accuracy: %s%% at epoch: %d',
-                   train_loss.numpy(),
+      steps_in_current_epoch = 0
+      while steps_in_current_epoch < per_epoch_steps:
+        time_callback.on_batch_begin(
+            steps_in_current_epoch + epoch * per_epoch_steps)
+        steps = _steps_to_run(steps_in_current_epoch, per_epoch_steps,
+                              steps_per_loop)
+        if steps == 1:
+          train_single_step(train_iter)
+        else:
+          # Converts steps to a Tensor to avoid tf.function retracing.
+          train_steps(train_iter, tf.convert_to_tensor(steps, dtype=tf.int32))
+        time_callback.on_batch_end(
+            steps_in_current_epoch + epoch * per_epoch_steps)
+        steps_in_current_epoch += steps
+
+      logging.info('Training loss: %s, accuracy: %s%% at epoch %d',
+                   train_loss.result().numpy(),
                    training_accuracy.result().numpy(),
-                   epoch)
+                   epoch + 1)
 
       if (not flags_obj.skip_eval and
           (epoch + 1) % flags_obj.epochs_between_evals == 0):
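The "Converts steps to a Tensor to avoid tf.function retracing" comment above matters because `tf.function` treats distinct Python scalars as distinct input signatures and retraces for each new value, whereas Tensor arguments of the same dtype and shape reuse a single trace. A standalone toy check, not part of the diff (the `loop` function and counters are mine):

# Trace-time side effects run once per trace, so counting them shows retraces.
import tensorflow as tf

trace_count = 0

@tf.function
def loop(n):
  global trace_count
  trace_count += 1            # executes only while tracing
  acc = tf.constant(0)
  for i in tf.range(n):
    acc += i
  return acc

loop(4)
loop(2)                        # different Python int -> new trace
assert trace_count == 2

trace_count = 0
loop(tf.convert_to_tensor(4, dtype=tf.int32))
loop(tf.convert_to_tensor(2, dtype=tf.int32))   # same signature -> no retrace
assert trace_count == 1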
@@ -283,12 +341,12 @@ def step_fn(inputs):
 
         test_iter = iter(test_ds)
         for _ in range(eval_steps):
-          test_step(next(test_iter))
+          test_step(test_iter)
 
         logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
                      test_loss.result().numpy(),
                      test_accuracy.result().numpy(),
-                     epoch)
+                     epoch + 1)
 
     time_callback.on_train_end()
 
@@ -297,7 +355,7 @@ def step_fn(inputs):
     if not flags_obj.skip_eval:
       eval_result = [test_loss.result().numpy(),
                      test_accuracy.result().numpy()]
-      train_result = [train_loss.numpy(),
+      train_result = [train_loss.result().numpy(),
                       training_accuracy.result().numpy()]
 
     stats = build_stats(train_result, eval_result, time_callback)
@@ -307,7 +365,8 @@ def step_fn(inputs):
 def main(_):
   model_helpers.apply_clean(flags.FLAGS)
   with logger.benchmark_context(flags.FLAGS):
-    return run(flags.FLAGS)
+    stats = run(flags.FLAGS)
+    logging.info('Run stats:\n%s', stats)
 
 
 if __name__ == '__main__':