Skip to content

Commit dfaf525

Browse files
[core] Only use runtime_options for training.
PiperOrigin-RevId: 363782489
1 parent b6d1ec2 commit dfaf525

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

official/core/base_trainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def get_runtime_options(config: ExperimentConfig):
163163
xla_options = {}
164164
if config.runtime.tpu_enable_xla_dynamic_padder is not None:
165165
xla_options["enable_xla_dynamic_padder"] = (
166-
config.runtime.enable_xla_dynamic_padder)
166+
config.runtime.tpu_enable_xla_dynamic_padder)
167167
return tf.distribute.RunOptions(
168168
experimental_xla_options=tf.tpu.XLAOptions(**xla_options))
169169

@@ -205,6 +205,8 @@ def __init__(self,
205205
self._optimizer = optimizer
206206
self._checkpoint_exporter = checkpoint_exporter
207207
self._recovery = None
208+
# Runtime options are only applied to train_step.
209+
# We use default for eval_step.
208210
self._runtime_options = get_runtime_options(config)
209211

210212
# Creates a shadow copy of the weights to store weights moving average.
@@ -407,8 +409,7 @@ def step_fn(inputs):
407409
self._validation_loss.update_state(logs[self.task.loss])
408410
return logs
409411

410-
distributed_outputs = self.strategy.run(
411-
step_fn, args=(next(iterator),), options=self._runtime_options)
412+
distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
412413
return tf.nest.map_structure(self.strategy.experimental_local_results,
413414
distributed_outputs)
414415

official/core/config_definitions.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,11 @@ class RuntimeConfig(base_config.Config):
140140
run_eagerly: bool = False
141141
batchnorm_spatial_persistent: bool = False
142142

143-
# XLA runtime
143+
# XLA runtime params.
144+
# XLA params are only applied to the train_step.
145+
# These augments can improve training speed. They can also improve eval, but
146+
# may reduce usability and users would need to make changes to code.
147+
144148
# Whether to enable XLA dynamic padder
145149
# infrastructure to handle dynamic shapes inputs inside XLA. True by
146150
# default. Disabling this may cause correctness issues with dynamic shapes

0 commit comments

Comments
 (0)