
Commit a8b6963

rachellj218 authored and tensorflower-gardener committed

Internal change

PiperOrigin-RevId: 272915002

1 parent b045ce7 commit a8b6963

File tree

3 files changed: +164, -98 lines


official/resnet/ctl/ctl_imagenet_main.py

Lines changed: 151 additions & 92 deletions
@@ -71,6 +71,20 @@ def build_stats(train_result, eval_result, time_callback):
 def get_input_dataset(flags_obj, strategy):
   """Returns the test and train input datasets."""
   dtype = flags_core.get_tf_dtype(flags_obj)
+  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+  batch_size = flags_obj.batch_size
+  if use_dataset_fn:
+    if batch_size % strategy.num_replicas_in_sync != 0:
+      raise ValueError(
+          'Batch size must be divisible by number of replicas : {}'.format(
+              strategy.num_replicas_in_sync))
+
+    # As auto rebatching is not supported in
+    # `experimental_distribute_datasets_from_function()` API, which is
+    # required when cloning dataset to multiple workers in eager mode,
+    # we use per-replica batch size.
+    batch_size = int(batch_size / strategy.num_replicas_in_sync)
+
   if flags_obj.use_synthetic_data:
     input_fn = common.get_synth_input_fn(
         height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
@@ -82,34 +96,51 @@ def get_input_dataset(flags_obj, strategy):
   else:
     input_fn = imagenet_preprocessing.input_fn

-  train_ds = input_fn(
-      is_training=True,
-      data_dir=flags_obj.data_dir,
-      batch_size=flags_obj.batch_size,
-      parse_record_fn=imagenet_preprocessing.parse_record,
-      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
-      dtype=dtype)
+  def _train_dataset_fn(ctx=None):
+    train_ds = input_fn(
+        is_training=True,
+        data_dir=flags_obj.data_dir,
+        batch_size=batch_size,
+        parse_record_fn=imagenet_preprocessing.parse_record,
+        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
+        dtype=dtype,
+        input_context=ctx,
+        drop_remainder=True)
+    return train_ds

   if strategy:
-    train_ds = strategy.experimental_distribute_dataset(train_ds)
+    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
+      train_ds = strategy.experimental_distribute_datasets_from_function(_train_dataset_fn)
+    else:
+      train_ds = strategy.experimental_distribute_dataset(_train_dataset_fn())
+  else:
+    train_ds = _train_dataset_fn()

   test_ds = None
   if not flags_obj.skip_eval:
-    test_ds = input_fn(
-        is_training=False,
-        data_dir=flags_obj.data_dir,
-        batch_size=flags_obj.batch_size,
-        parse_record_fn=imagenet_preprocessing.parse_record,
-        dtype=dtype)
+    def _test_data_fn(ctx=None):
+      test_ds = input_fn(
+          is_training=False,
+          data_dir=flags_obj.data_dir,
+          batch_size=batch_size,
+          parse_record_fn=imagenet_preprocessing.parse_record,
+          dtype=dtype,
+          input_context=ctx)
+      return test_ds

-  if strategy:
-    test_ds = strategy.experimental_distribute_dataset(test_ds)
+    if strategy:
+      if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
+        test_ds = strategy.experimental_distribute_datasets_from_function(_test_data_fn)
+      else:
+        test_ds = strategy.experimental_distribute_dataset(_test_data_fn())
+    else:
+      test_ds = _test_data_fn()

   return train_ds, test_ds


 def get_num_train_iterations(flags_obj):
-  """Returns the number of training stesps, train and test epochs."""
+  """Returns the number of training steps, train and test epochs."""
   train_steps = (
       imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
   train_epochs = flags_obj.train_epochs
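
The hunk above replaces a directly-built dataset with a dataset-function pattern: under TPUStrategy each input pipeline is built by `_train_dataset_fn` with a per-replica batch size, since `experimental_distribute_datasets_from_function()` does no automatic rebatching. A minimal self-contained sketch of the same pattern (the toy dataset and MirroredStrategy are illustrative, assuming a TF release of this era that exposes the experimental API):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
global_batch_size = 256
if global_batch_size % strategy.num_replicas_in_sync != 0:
  raise ValueError('Batch size must be divisible by number of replicas.')
# Per-replica batch size, since this code path does no automatic rebatching.
per_replica_batch_size = global_batch_size // strategy.num_replicas_in_sync

def dataset_fn(ctx=None):
  # `ctx` is a tf.distribute.InputContext the strategy passes in so each
  # input pipeline can shard its slice of the data.
  features = tf.zeros([1024, 8])
  labels = tf.zeros([1024], dtype=tf.int32)
  ds = tf.data.Dataset.from_tensor_slices((features, labels))
  if ctx is not None:
    ds = ds.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
  return ds.repeat().batch(per_replica_batch_size, drop_remainder=True)

dist_ds = strategy.experimental_distribute_datasets_from_function(dataset_fn)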
@@ -124,6 +155,15 @@ def get_num_train_iterations(flags_obj):
   return train_steps, train_epochs, eval_steps


+def _steps_to_run(steps_in_current_epoch, steps_per_epoch, steps_per_loop):
+  """Calculates steps to run on device."""
+  if steps_per_loop <= 0:
+    raise ValueError('steps_per_loop should be positive integer.')
+  if steps_per_loop == 1:
+    return steps_per_loop
+  return min(steps_per_loop, steps_per_epoch - steps_in_current_epoch)
+
+
 def run(flags_obj):
   """Run ResNet ImageNet training and eval loop using custom training loops.

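
`_steps_to_run` clamps each inner loop to the epoch boundary, so the device runs `steps_per_loop` steps at a time plus a shorter final burst. A quick worked example with illustrative numbers:

# With 1251 steps per epoch and steps_per_loop=500, the inner loops run
# 500, 500, and finally 251 steps.
per_epoch_steps, steps_per_loop = 1251, 500
done, bursts = 0, []
while done < per_epoch_steps:
  steps = min(steps_per_loop, per_epoch_steps - done)
  bursts.append(steps)
  done += steps
assert bursts == [500, 500, 251]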
@@ -152,33 +192,45 @@ def run(flags_obj):
       num_gpus=flags_obj.num_gpus,
       num_workers=distribution_utils.configure_cluster(),
       all_reduce_alg=flags_obj.all_reduce_alg,
-      num_packs=flags_obj.num_packs)
+      num_packs=flags_obj.num_packs,
+      tpu_address=flags_obj.tpu)

   train_ds, test_ds = get_input_dataset(flags_obj, strategy)
-  train_steps, train_epochs, eval_steps = get_num_train_iterations(flags_obj)
+  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
+      flags_obj)
+  steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
+  logging.info("Training %d epochs, each epoch has %d steps, "
+               "total steps: %d; Eval %d steps",
+               train_epochs, per_epoch_steps, train_epochs * per_epoch_steps,
+               eval_steps)

   time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
                                           flags_obj.log_steps)

-  strategy_scope = distribution_utils.get_strategy_scope(strategy)
-  with strategy_scope:
+  with distribution_utils.get_strategy_scope(strategy):
     model = resnet_model.resnet50(
         num_classes=imagenet_preprocessing.NUM_CLASSES,
         batch_size=flags_obj.batch_size,
         use_l2_regularizer=not flags_obj.single_l2_loss_op)

-    optimizer = tf.keras.optimizers.SGD(
-        learning_rate=common.BASE_LEARNING_RATE, momentum=0.9,
-        nesterov=True)
-
-    if flags_obj.fp16_implementation == "graph_rewrite":
+    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
+        batch_size=flags_obj.batch_size,
+        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
+        warmup_epochs=common.LR_SCHEDULE[0][1],
+        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
+        multipliers=list(p[0] for p in common.LR_SCHEDULE),
+        compute_lr_on_cpu=True)
+    optimizer = common.get_optimizer(lr_schedule)
+
+    if flags_obj.fp16_implementation == 'graph_rewrite':
       if not flags_obj.use_tf_function:
-        raise ValueError("--fp16_implementation=graph_rewrite requires "
-                         "--use_tf_function to be true")
+        raise ValueError('--fp16_implementation=graph_rewrite requires '
+                         '--use_tf_function to be true')
       loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
       optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
           optimizer, loss_scale)

+    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
     training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
         'training_accuracy', dtype=tf.float32)
     test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
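
The hand-rolled SGD schedule is replaced by `common.get_optimizer` over a warmup-plus-piecewise-constant schedule computed on the host CPU. The graph-rewrite fp16 path is unchanged: it wraps the optimizer so float16 casts and loss scaling happen inside the rewritten graph. A minimal sketch of that wrapper (the SGD settings are illustrative; the API is the TF 2.0-era one used above):

import tensorflow as tf

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
# Rewrites the compute graph to run in float16 where safe, applying a fixed
# loss scale of 128 to keep small gradients from flushing to zero.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    optimizer, loss_scale=128)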
@@ -187,55 +239,56 @@ def run(flags_obj):

     trainable_variables = model.trainable_variables

-    def train_step(train_ds_inputs):
-      """Training StepFn."""
-      def step_fn(inputs):
-        """Per-Replica StepFn."""
-        images, labels = inputs
-        with tf.GradientTape() as tape:
-          logits = model(images, training=True)
-
-          prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
-              labels, logits)
-          loss = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
-          num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
-
-          if flags_obj.single_l2_loss_op:
-            filtered_variables = [
-                tf.reshape(v, (-1,))
-                for v in trainable_variables
-                if 'bn' not in v.name
-            ]
-            l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
-                tf.concat(filtered_variables, axis=0))
-            loss += (l2_loss / num_replicas)
-          else:
-            loss += (tf.reduce_sum(model.losses) / num_replicas)
-
-          # Scale the loss
-          if flags_obj.dtype == "fp16":
-            loss = optimizer.get_scaled_loss(loss)
-
-        grads = tape.gradient(loss, trainable_variables)
-
-        # Unscale the grads
+    def step_fn(inputs):
+      """Per-Replica StepFn."""
+      images, labels = inputs
+      with tf.GradientTape() as tape:
+        logits = model(images, training=True)
+
+        prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
+            labels, logits)
+        loss = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
+        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
+
+        if flags_obj.single_l2_loss_op:
+          filtered_variables = [
+              tf.reshape(v, (-1,))
+              for v in trainable_variables
+              if 'bn' not in v.name
+          ]
+          l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
+              tf.concat(filtered_variables, axis=0))
+          loss += (l2_loss / num_replicas)
+        else:
+          loss += (tf.reduce_sum(model.losses) / num_replicas)
+
+        # Scale the loss
         if flags_obj.dtype == "fp16":
-          grads = optimizer.get_unscaled_gradients(grads)
+          loss = optimizer.get_scaled_loss(loss)

-        optimizer.apply_gradients(zip(grads, trainable_variables))
+      grads = tape.gradient(loss, trainable_variables)

-        training_accuracy.update_state(labels, logits)
-        return loss
+      # Unscale the grads
+      if flags_obj.dtype == "fp16":
+        grads = optimizer.get_unscaled_gradients(grads)

+      optimizer.apply_gradients(zip(grads, trainable_variables))
+      train_loss.update_state(loss)
+      training_accuracy.update_state(labels, logits)
+
+    @tf.function
+    def train_steps(iterator, steps):
+      """Performs distributed training steps in a loop."""
+      for _ in tf.range(steps):
+        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
+
+    def train_single_step(iterator):
       if strategy:
-        per_replica_losses = strategy.experimental_run_v2(
-            step_fn, args=(train_ds_inputs,))
-        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
-                               axis=None)
+        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
       else:
-        return step_fn(train_ds_inputs)
+        return step_fn(next(iterator))

-    def test_step(test_ds_inputs):
+    def test_step(iterator):
       """Evaluation StepFn."""
       def step_fn(inputs):
         images, labels = inputs
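
A detail worth keeping in mind when reading `step_fn`: each replica divides its summed cross entropy by the global batch size and its L2 term by the replica count, so summing across replicas reproduces the single-worker objective. A small numeric check of that bookkeeping (illustrative numbers):

import numpy as np

global_batch, num_replicas = 256, 8
per_replica = global_batch // num_replicas
example_losses = np.random.rand(global_batch)
l2_loss = 0.37

replica_objectives = [
    example_losses[r * per_replica:(r + 1) * per_replica].sum() / global_batch
    + l2_loss / num_replicas
    for r in range(num_replicas)
]
# Summing over replicas gives mean cross entropy plus a single l2 term.
assert np.isclose(sum(replica_objectives), example_losses.mean() + l2_loss)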
@@ -247,34 +300,39 @@ def step_fn(inputs):
         test_accuracy.update_state(labels, logits)

       if strategy:
-        strategy.experimental_run_v2(step_fn, args=(test_ds_inputs,))
+        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
       else:
-        step_fn(test_ds_inputs)
+        step_fn(next(iterator))

     if flags_obj.use_tf_function:
-      train_step = tf.function(train_step)
+      train_single_step = tf.function(train_single_step)
       test_step = tf.function(test_step)

+    train_iter = iter(train_ds)
     time_callback.on_train_begin()
     for epoch in range(train_epochs):
-
-      train_iter = iter(train_ds)
-      total_loss = 0.0
+      train_loss.reset_states()
       training_accuracy.reset_states()

-      for step in range(train_steps):
-        optimizer.lr = common.learning_rate_schedule(
-            epoch, step, train_steps, flags_obj.batch_size)
-
-        time_callback.on_batch_begin(step+epoch*train_steps)
-        total_loss += train_step(next(train_iter))
-        time_callback.on_batch_end(step+epoch*train_steps)
-
-      train_loss = total_loss / train_steps
-      logging.info('Training loss: %s, accuracy: %s%% at epoch: %d',
-                   train_loss.numpy(),
+      steps_in_current_epoch = 0
+      while steps_in_current_epoch < per_epoch_steps:
+        time_callback.on_batch_begin(
+            steps_in_current_epoch+epoch*per_epoch_steps)
+        steps = _steps_to_run(steps_in_current_epoch, per_epoch_steps,
+                              steps_per_loop)
+        if steps == 1:
+          train_single_step(train_iter)
+        else:
+          # Converts steps to a Tensor to avoid tf.function retracing.
+          train_steps(train_iter, tf.convert_to_tensor(steps, dtype=tf.int32))
+        time_callback.on_batch_end(
+            steps_in_current_epoch+epoch*per_epoch_steps)
+        steps_in_current_epoch += steps
+
+      logging.info('Training loss: %s, accuracy: %s%% at epoch %d',
+                   train_loss.result().numpy(),
                    training_accuracy.result().numpy(),
-                   epoch)
+                   epoch + 1)

     if (not flags_obj.skip_eval and
         (epoch + 1) % flags_obj.epochs_between_evals == 0):
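
The loop above hinges on `train_steps` receiving the step count as a scalar Tensor: the `tf.function` then traces once, with `tf.range` compiled into a graph-side loop, instead of retracing for each distinct Python int (say 500 for full loops and 251 for the last partial one). A standalone sketch of the retracing trick, with a trivial step in place of the distributed ResNet step:

import tensorflow as tf

total = tf.Variable(0, dtype=tf.int64)

@tf.function
def run_steps(iterator, steps):
  # `steps` is a scalar int32 Tensor, so one trace covers every step count.
  for _ in tf.range(steps):
    total.assign_add(next(iterator))

iterator = iter(tf.data.Dataset.range(10000))
run_steps(iterator, tf.convert_to_tensor(500, dtype=tf.int32))
run_steps(iterator, tf.convert_to_tensor(251, dtype=tf.int32))  # no retrace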
@@ -283,12 +341,12 @@ def step_fn(inputs):

       test_iter = iter(test_ds)
       for _ in range(eval_steps):
-        test_step(next(test_iter))
+        test_step(test_iter)

       logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
                    test_loss.result().numpy(),
                    test_accuracy.result().numpy(),
-                   epoch)
+                   epoch + 1)

     time_callback.on_train_end()

@@ -297,7 +355,7 @@ def step_fn(inputs):
   if not flags_obj.skip_eval:
     eval_result = [test_loss.result().numpy(),
                    test_accuracy.result().numpy()]
-    train_result = [train_loss.numpy(),
+    train_result = [train_loss.result().numpy(),
                     training_accuracy.result().numpy()]

   stats = build_stats(train_result, eval_result, time_callback)
@@ -307,7 +365,8 @@ def step_fn(inputs):
 def main(_):
   model_helpers.apply_clean(flags.FLAGS)
   with logger.benchmark_context(flags.FLAGS):
-    return run(flags.FLAGS)
+    stats = run(flags.FLAGS)
+    logging.info('Run stats:\n%s', stats)


 if __name__ == '__main__':

official/vision/image_classification/common.py

Lines changed: 7 additions & 0 deletions
@@ -353,6 +353,13 @@ def define_keras_flags(dynamic_loss_scale=True):
   flags.DEFINE_boolean(
       name='enable_checkpoint_and_export', default=False,
       help='Whether to enable a checkpoint callback and export the savedmodel.')
+  flags.DEFINE_string(
+      name='tpu', default='', help='TPU address to connect to.')
+  flags.DEFINE_integer(
+      name='steps_per_loop', default=1,
+      help='Number of steps per graph-mode loop. Only training step happens '
+      'inside the loop. Callbacks will not be called inside. Will be capped at '
+      'steps per epoch.')


 def get_synth_input_fn(height, width, num_channels, num_classes,
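
The two new flags feed `run()`: `--tpu` is passed to the strategy factory as `tpu_address`, and `--steps_per_loop` sets the inner-loop length (capped at one epoch, with callbacks firing only at loop boundaries). A small absl sketch of how such a flag is parsed (the flag definition is copied from the diff; the program name and address are illustrative):

from absl import flags

flags.DEFINE_string(name='tpu', default='', help='TPU address to connect to.')
FLAGS = flags.FLAGS

# absl parses argv-style lists; the first element is the program name.
FLAGS(['ctl_imagenet_main.py', '--tpu=grpc://10.0.0.2:8470'])
print(FLAGS.tpu)  # grpc://10.0.0.2:8470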

official/vision/image_classification/imagenet_preprocessing.py

Lines changed: 6 additions & 6 deletions
@@ -115,9 +115,9 @@ def process_record_dataset(dataset,
   if is_training:
     # Shuffles records before repeating to respect epoch boundaries.
     dataset = dataset.shuffle(buffer_size=shuffle_buffer)
+    # Repeats the dataset for the number of epochs to train.
+    dataset = dataset.repeat()

-  # Repeats the dataset for the number of epochs to train.
-  dataset = dataset.repeat(num_epochs)

   # Parses the raw records into images and labels.
   dataset = dataset.map(
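
Dropping the `num_epochs` argument makes the training pipeline repeat without end; epoch boundaries are now enforced on the host by the step-counting loop in ctl_imagenet_main.py rather than by dataset exhaustion. A minimal illustration of the unbounded repeat:

import tensorflow as tf

ds = tf.data.Dataset.range(3).repeat()  # no count: cycles indefinitely
it = iter(ds)
print([int(next(it)) for _ in range(7)])  # [0, 1, 2, 0, 1, 2, 0]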
@@ -133,10 +133,10 @@ def process_record_dataset(dataset,
   # on how many devices are present.
   dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

-  if tf_data_experimental_slack:
-    options = tf.data.Options()
-    options.experimental_slack = True
-    dataset = dataset.with_options(options)
+  options = tf.data.Options()
+  options.experimental_slack = tf_data_experimental_slack
+  options.experimental_allow_stateful = True
+  dataset = dataset.with_options(options)

   return dataset

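
The options block now runs unconditionally: slack is toggled by the existing flag, and stateful ops such as shuffle are explicitly allowed, which matters when the pipeline is cloned to multiple workers. A minimal sketch, assuming a TF build of this era where `tf.data.Options` still exposes `experimental_allow_stateful` (later releases replaced it with `experimental_external_state_policy`):

import tensorflow as tf

dataset = tf.data.Dataset.range(1000).shuffle(100).batch(32)

options = tf.data.Options()
options.experimental_slack = True           # add prefetch slack on the host
options.experimental_allow_stateful = True  # permit stateful ops when cloning
dataset = dataset.with_options(options)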
