33
33
# pylint: enable=g-bad-import-order
34
34
35
35
from official .r1 .utils import export
36
+ from tensorflow .contrib import cluster_resolver as contrib_cluster_resolver
37
+ from tensorflow .contrib import opt as contrib_opt
38
+ from tensorflow .contrib import tpu as contrib_tpu
36
39
from official .r1 .utils import tpu as tpu_util
37
40
from official .transformer import compute_bleu
38
41
from official .transformer import translate
@@ -115,8 +118,10 @@ def model_fn(features, labels, mode, params):
115
118
metric_fn = lambda logits , labels : (
116
119
metrics .get_eval_metrics (logits , labels , params = params ))
117
120
eval_metrics = (metric_fn , [logits , labels ])
118
- return tf .contrib .tpu .TPUEstimatorSpec (
119
- mode = mode , loss = loss , predictions = {"predictions" : logits },
121
+ return contrib_tpu .TPUEstimatorSpec (
122
+ mode = mode ,
123
+ loss = loss ,
124
+ predictions = {"predictions" : logits },
120
125
eval_metrics = eval_metrics )
121
126
return tf .estimator .EstimatorSpec (
122
127
mode = mode , loss = loss , predictions = {"predictions" : logits },
@@ -128,12 +133,14 @@ def model_fn(features, labels, mode, params):
128
133
# in TensorBoard.
129
134
metric_dict ["minibatch_loss" ] = loss
130
135
if params ["use_tpu" ]:
131
- return tf .contrib .tpu .TPUEstimatorSpec (
132
- mode = mode , loss = loss , train_op = train_op ,
136
+ return contrib_tpu .TPUEstimatorSpec (
137
+ mode = mode ,
138
+ loss = loss ,
139
+ train_op = train_op ,
133
140
host_call = tpu_util .construct_scalar_host_call (
134
- metric_dict = metric_dict , model_dir = params [ "model_dir" ],
135
- prefix = "training/" )
136
- )
141
+ metric_dict = metric_dict ,
142
+ model_dir = params [ "model_dir" ],
143
+ prefix = "training/" ) )
137
144
record_scalars (metric_dict )
138
145
return tf .estimator .EstimatorSpec (mode = mode , loss = loss , train_op = train_op )
139
146
@@ -173,14 +180,14 @@ def get_train_op_and_metrics(loss, params):
173
180
174
181
# Create optimizer. Use LazyAdamOptimizer from TF contrib, which is faster
175
182
# than the TF core Adam optimizer.
176
- optimizer = tf . contrib . opt .LazyAdamOptimizer (
183
+ optimizer = contrib_opt .LazyAdamOptimizer (
177
184
learning_rate ,
178
185
beta1 = params ["optimizer_adam_beta1" ],
179
186
beta2 = params ["optimizer_adam_beta2" ],
180
187
epsilon = params ["optimizer_adam_epsilon" ])
181
188
182
189
if params ["use_tpu" ] and params ["tpu" ] != tpu_util .LOCAL :
183
- optimizer = tf . contrib . tpu .CrossShardOptimizer (optimizer )
190
+ optimizer = contrib_tpu .CrossShardOptimizer (optimizer )
184
191
185
192
# Uses automatic mixed precision FP16 training if on GPU.
186
193
if params ["dtype" ] == "fp16" :
@@ -528,31 +535,31 @@ def construct_estimator(flags_obj, params, schedule_manager):
528
535
model_fn = model_fn , model_dir = flags_obj .model_dir , params = params ,
529
536
config = tf .estimator .RunConfig (train_distribute = distribution_strategy ))
530
537
531
- tpu_cluster_resolver = tf . contrib . cluster_resolver .TPUClusterResolver (
538
+ tpu_cluster_resolver = contrib_cluster_resolver .TPUClusterResolver (
532
539
tpu = flags_obj .tpu ,
533
540
zone = flags_obj .tpu_zone ,
534
- project = flags_obj .tpu_gcp_project
535
- )
541
+ project = flags_obj .tpu_gcp_project )
536
542
537
- tpu_config = tf . contrib . tpu .TPUConfig (
543
+ tpu_config = contrib_tpu .TPUConfig (
538
544
iterations_per_loop = schedule_manager .single_iteration_train_steps ,
539
545
num_shards = flags_obj .num_tpu_shards )
540
546
541
- run_config = tf . contrib . tpu .RunConfig (
547
+ run_config = contrib_tpu .RunConfig (
542
548
cluster = tpu_cluster_resolver ,
543
549
model_dir = flags_obj .model_dir ,
544
550
session_config = tf .ConfigProto (
545
551
allow_soft_placement = True , log_device_placement = True ),
546
552
tpu_config = tpu_config )
547
553
548
- return tf . contrib . tpu .TPUEstimator (
554
+ return contrib_tpu .TPUEstimator (
549
555
model_fn = model_fn ,
550
556
use_tpu = params ["use_tpu" ] and flags_obj .tpu != tpu_util .LOCAL ,
551
557
train_batch_size = schedule_manager .batch_size ,
552
558
eval_batch_size = schedule_manager .batch_size ,
553
559
params = {
554
560
# TPUEstimator needs to populate batch_size itself due to sharding.
555
- key : value for key , value in params .items () if key != "batch_size" },
561
+ key : value for key , value in params .items () if key != "batch_size"
562
+ },
556
563
config = run_config )
557
564
558
565
0 commit comments