@@ -163,7 +163,7 @@ def get_runtime_options(config: ExperimentConfig):
163
163
xla_options = {}
164
164
if config .runtime .tpu_enable_xla_dynamic_padder is not None :
165
165
xla_options ["enable_xla_dynamic_padder" ] = (
166
- config .runtime .enable_xla_dynamic_padder )
166
+ config .runtime .tpu_enable_xla_dynamic_padder )
167
167
return tf .distribute .RunOptions (
168
168
experimental_xla_options = tf .tpu .XLAOptions (** xla_options ))
169
169
@@ -205,6 +205,8 @@ def __init__(self,
205
205
self ._optimizer = optimizer
206
206
self ._checkpoint_exporter = checkpoint_exporter
207
207
self ._recovery = None
208
+ # Runtime options are only applied to train_step.
209
+ # We use default for eval_step.
208
210
self ._runtime_options = get_runtime_options (config )
209
211
210
212
# Creates a shadow copy of the weights to store weights moving average.
@@ -407,8 +409,7 @@ def step_fn(inputs):
407
409
self ._validation_loss .update_state (logs [self .task .loss ])
408
410
return logs
409
411
410
- distributed_outputs = self .strategy .run (
411
- step_fn , args = (next (iterator ),), options = self ._runtime_options )
412
+ distributed_outputs = self .strategy .run (step_fn , args = (next (iterator ),))
412
413
return tf .nest .map_structure (self .strategy .experimental_local_results ,
413
414
distributed_outputs )
414
415
0 commit comments