This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 6b31c0f

qlzh727 authored and Mesh TensorFlow Team committed
Explicitly import estimator from tensorflow as a separate import instead of accessing it via tf.estimator and depend on the tensorflow estimator target.
PiperOrigin-RevId: 437342499
1 parent 58153bf commit 6b31c0f
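
The change is mechanical across the touched files: the Estimator module is imported once under a local alias instead of being reached through attribute access on the tf module at every call site, which lets each build target depend on the TensorFlow Estimator package explicitly. A minimal sketch of the pattern in Python (illustrative only, not code from the repository; it assumes a TensorFlow release where tensorflow.compat.v1.estimator is available):

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator  # new explicit import

# Old style, resolved through attribute access on the tf module:
#   mode_key = tf.estimator.ModeKeys.TRAIN
# New style, resolved through the separately imported, aliased module:
mode_key = tf_estimator.ModeKeys.TRAIN

# Both spellings resolve to the same Estimator APIs, so runtime behavior is
# unchanged; only the import (and the matching build dependency) is explicit.
print(mode_key, tf_estimator.EstimatorSpec is tf.estimator.EstimatorSpec)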

File tree

12 files changed (+98, -86 lines)

examples/mnist.py

Lines changed: 14 additions & 13 deletions
@@ -25,6 +25,7 @@
 import mesh_tensorflow as mtf
 import mnist_dataset as dataset  # local file import
 import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1 import estimator as tf_estimator


 tf.flags.DEFINE_string("data_dir", "/tmp/mnist_data",
@@ -126,7 +127,7 @@ def model_fn(features, labels, mode, params):
   mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
       mesh_shape, layout_rules, mesh_devices)

-  if mode == tf.estimator.ModeKeys.TRAIN:
+  if mode == tf_estimator.ModeKeys.TRAIN:
     var_grads = mtf.gradients(
         [loss], [v.outputs[0] for v in graph.trainable_variables])
     optimizer = mtf.optimize.AdafactorOptimizer()
@@ -136,11 +137,11 @@ def model_fn(features, labels, mode, params):
   restore_hook = mtf.MtfRestoreHook(lowering)

   tf_logits = lowering.export_to_tf_tensor(logits)
-  if mode != tf.estimator.ModeKeys.PREDICT:
+  if mode != tf_estimator.ModeKeys.PREDICT:
     tf_loss = lowering.export_to_tf_tensor(loss)
     tf.summary.scalar("loss", tf_loss)

-  if mode == tf.estimator.ModeKeys.TRAIN:
+  if mode == tf_estimator.ModeKeys.TRAIN:
     tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
     tf_update_ops.append(tf.assign_add(global_step, 1))
     train_op = tf.group(tf_update_ops)
@@ -169,25 +170,25 @@ def model_fn(features, labels, mode, params):
     tf.summary.scalar("train_accuracy", accuracy[1])

     # restore_hook must come before saver_hook
-    return tf.estimator.EstimatorSpec(
-        tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
+    return tf_estimator.EstimatorSpec(
+        tf_estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
         training_chief_hooks=[restore_hook, saver_hook])

-  if mode == tf.estimator.ModeKeys.PREDICT:
+  if mode == tf_estimator.ModeKeys.PREDICT:
     predictions = {
         "classes": tf.argmax(tf_logits, axis=1),
         "probabilities": tf.nn.softmax(tf_logits),
     }
-    return tf.estimator.EstimatorSpec(
-        mode=tf.estimator.ModeKeys.PREDICT,
+    return tf_estimator.EstimatorSpec(
+        mode=tf_estimator.ModeKeys.PREDICT,
         predictions=predictions,
         prediction_hooks=[restore_hook],
         export_outputs={
-            "classify": tf.estimator.export.PredictOutput(predictions)
+            "classify": tf_estimator.export.PredictOutput(predictions)
         })
-  if mode == tf.estimator.ModeKeys.EVAL:
-    return tf.estimator.EstimatorSpec(
-        mode=tf.estimator.ModeKeys.EVAL,
+  if mode == tf_estimator.ModeKeys.EVAL:
+    return tf_estimator.EstimatorSpec(
+        mode=tf_estimator.ModeKeys.EVAL,
         loss=tf_loss,
         evaluation_hooks=[restore_hook],
         eval_metric_ops={
@@ -199,7 +200,7 @@ def model_fn(features, labels, mode, params):

 def run_mnist():
   """Run MNIST training and eval loop."""
-  mnist_classifier = tf.estimator.Estimator(
+  mnist_classifier = tf_estimator.Estimator(
       model_fn=model_fn,
       model_dir=FLAGS.model_dir)


examples/toy_model_tpu.py

Lines changed: 7 additions & 6 deletions
@@ -22,6 +22,7 @@
 import mesh_tensorflow as mtf
 import numpy
 import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1 import estimator as tf_estimator

 from tensorflow.python.data.ops.dataset_ops import Dataset
 from tensorflow.python.platform import flags
@@ -176,7 +177,7 @@ def model_fn(features, labels, mode, params):
   logits, loss = toy_model(features, mesh)

   # TRAIN mode
-  if mode == tf.estimator.ModeKeys.TRAIN:
+  if mode == tf_estimator.ModeKeys.TRAIN:
     var_grads = mtf.gradients([loss],
                               [v.outputs[0] for v in graph.trainable_variables])
     if FLAGS.optimizer == 'Adafactor':
@@ -193,7 +194,7 @@ def model_fn(features, labels, mode, params):

   tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

-  if mode == tf.estimator.ModeKeys.TRAIN:
+  if mode == tf_estimator.ModeKeys.TRAIN:
     tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
     tf_update_ops.append(tf.assign_add(global_step, 1))
     tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
@@ -204,7 +205,7 @@ def model_fn(features, labels, mode, params):
   with mtf.utils.outside_all_rewrites():
     # Copy master variables to slices. Must be called first.
     restore_hook = mtf.MtfRestoreHook(lowering)
-    if mode == tf.estimator.ModeKeys.TRAIN:
+    if mode == tf_estimator.ModeKeys.TRAIN:
       saver = tf.train.Saver(
           tf.global_variables(),
           sharded=True,
@@ -221,11 +222,11 @@ def model_fn(features, labels, mode, params):
          listeners=[saver_listener])

       return tpu_estimator.TPUEstimatorSpec(
-          tf.estimator.ModeKeys.TRAIN,
+          tf_estimator.ModeKeys.TRAIN,
           loss=tf_loss,
           train_op=train_op,
           training_hooks=[restore_hook, saver_hook])
-    elif mode == tf.estimator.ModeKeys.EVAL:
+    elif mode == tf_estimator.ModeKeys.EVAL:

       def metric_fn(tf_logits):
         mean_logits = tf.metrics.mean(tf_logits)
@@ -234,7 +235,7 @@ def metric_fn(tf_logits):
       eval_metrics = (metric_fn, [tf_logits])

       return tpu_estimator.TPUEstimatorSpec(
-          tf.estimator.ModeKeys.EVAL,
+          tf_estimator.ModeKeys.EVAL,
           evaluation_hooks=[restore_hook],
           loss=tf_loss,
           eval_metrics=eval_metrics)

mesh_tensorflow/bert/run_classifier.py

Lines changed: 14 additions & 13 deletions
@@ -29,6 +29,7 @@
 import mesh_tensorflow.bert.tokenization as tokenization
 from six.moves import range
 import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1 import estimator as tf_estimator

 flags = tf.flags

@@ -694,7 +695,7 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
                                            [batch_dim, seq_dim])
     mtf_label_ids = mtf.import_tf_tensor(mesh, label_ids, [batch_dim])

-    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
+    is_training = (mode == tf_estimator.ModeKeys.TRAIN)

     (total_loss, per_example_loss, logits,
      probabilities) = create_model(bert_config, is_training, mtf_input_ids,
@@ -705,7 +706,7 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
     per_example_loss = mtf.anonymize(per_example_loss)
     logits = mtf.anonymize(logits)

-    if mode == tf.estimator.ModeKeys.TRAIN:
+    if mode == tf_estimator.ModeKeys.TRAIN:
       _, update_ops = optimization_lib.create_optimizer(
           total_loss,
           learning_rate,
@@ -718,13 +719,13 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
     lowering = mtf.Lowering(graph, {mesh: mesh_impl})
     tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

-    if mode == tf.estimator.ModeKeys.TRAIN:
+    if mode == tf_estimator.ModeKeys.TRAIN:
       global_step = tf.train.get_global_step()
       tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
       tf_update_ops.append(tf.assign_add(global_step, 1))
       tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
       train_op = tf.group(tf_update_ops)
-    elif mode == tf.estimator.ModeKeys.EVAL:
+    elif mode == tf_estimator.ModeKeys.EVAL:

       def metric_fn(per_example_loss, label_ids, logits, is_real_example):
         predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
@@ -768,7 +769,7 @@ def tpu_scaffold():
     with mtf.utils.outside_all_rewrites():
       # Copy master variables to slices. Must be called first.
       restore_hook = mtf.MtfRestoreHook(lowering)
-      if mode == tf.estimator.ModeKeys.TRAIN:
+      if mode == tf_estimator.ModeKeys.TRAIN:
         saver = tf.train.Saver(
             tf.global_variables(),
             sharded=True,
@@ -784,21 +785,21 @@ def tpu_scaffold():
            saver=saver,
            listeners=[saver_listener])

-        return tf.estimator.tpu.TPUEstimatorSpec(
+        return tf_estimator.tpu.TPUEstimatorSpec(
            mode,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook],
            scaffold_fn=scaffold_fn)
-      elif mode == tf.estimator.ModeKeys.EVAL:
-        return tf.estimator.tpu.TPUEstimatorSpec(
+      elif mode == tf_estimator.ModeKeys.EVAL:
+        return tf_estimator.tpu.TPUEstimatorSpec(
            mode,
            evaluation_hooks=[restore_hook],
            loss=tf_loss,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn)
       else:
-        return tf.estimator.tpu.TPUEstimatorSpec(
+        return tf_estimator.tpu.TPUEstimatorSpec(
            mode,
            prediction_hooks=[restore_hook],
            predictions={
@@ -925,15 +926,15 @@ def main(_):
   tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
       FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

-  run_config = tf.estimator.tpu.RunConfig(
+  run_config = tf_estimator.tpu.RunConfig(
       cluster=tpu_cluster_resolver,
       master=FLAGS.master,
       model_dir=FLAGS.output_dir,
       save_checkpoints_steps=FLAGS.save_checkpoints_steps,
-      tpu_config=tf.estimator.tpu.TPUConfig(
+      tpu_config=tf_estimator.tpu.TPUConfig(
           iterations_per_loop=FLAGS.iterations_per_loop,
           num_cores_per_replica=1,
-          per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig
+          per_host_input_for_training=tf_estimator.tpu.InputPipelineConfig
           .BROADCAST))

   train_examples = None
@@ -956,7 +957,7 @@ def main(_):

   # If TPU is not available, this will fall back to normal Estimator on CPU
   # or GPU.
-  estimator = tf.estimator.tpu.TPUEstimator(
+  estimator = tf_estimator.tpu.TPUEstimator(
       use_tpu=FLAGS.use_tpu,
       model_fn=model_fn,
       config=run_config,

mesh_tensorflow/bert/run_pretraining.py

Lines changed: 15 additions & 14 deletions
@@ -27,6 +27,7 @@
 import mesh_tensorflow.bert.optimization as optimization_lib
 from six.moves import range
 import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1 import estimator as tf_estimator

 flags = tf.flags

@@ -201,7 +202,7 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
     mtf_next_sentence_labels = mtf.import_tf_tensor(
         mesh, next_sentence_labels, [batch_dim])

-    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
+    is_training = (mode == tf_estimator.ModeKeys.TRAIN)

     model = bert_lib.BertModel(
         config=bert_config,
@@ -230,7 +231,7 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
     next_sentence_logits = mtf.anonymize(next_sentence_logits)

     # TRAIN mode
-    if mode == tf.estimator.ModeKeys.TRAIN:
+    if mode == tf_estimator.ModeKeys.TRAIN:
       _, update_ops = optimization_lib.create_optimizer(
           total_loss + extra_loss,
           learning_rate,
@@ -243,13 +244,13 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument

     tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

-    if mode == tf.estimator.ModeKeys.TRAIN:
+    if mode == tf_estimator.ModeKeys.TRAIN:
       global_step = tf.train.get_global_step()
       tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
       tf_update_ops.append(tf.assign_add(global_step, 1))
       tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
       train_op = tf.group(tf_update_ops)
-    elif mode == tf.estimator.ModeKeys.EVAL:
+    elif mode == tf_estimator.ModeKeys.EVAL:

       def metric_fn(masked_lm_example_loss, masked_lm_logits, masked_lm_ids,
                     masked_lm_weights, next_sentence_example_loss,
@@ -298,7 +299,7 @@ def metric_fn(masked_lm_example_loss, masked_lm_logits, masked_lm_ids,
     with mtf.utils.outside_all_rewrites():
       # Copy master variables to slices. Must be called first.
       restore_hook = mtf.MtfRestoreHook(lowering)
-      if mode == tf.estimator.ModeKeys.TRAIN:
+      if mode == tf_estimator.ModeKeys.TRAIN:
         saver = tf.train.Saver(
             tf.global_variables(),
             sharded=True,
@@ -314,14 +315,14 @@ def metric_fn(masked_lm_example_loss, masked_lm_logits, masked_lm_ids,
            saver=saver,
            listeners=[saver_listener])

-        return tf.estimator.tpu.TPUEstimatorSpec(
-            tf.estimator.ModeKeys.TRAIN,
+        return tf_estimator.tpu.TPUEstimatorSpec(
+            tf_estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook])
-      elif mode == tf.estimator.ModeKeys.EVAL:
-        return tf.estimator.tpu.TPUEstimatorSpec(
-            tf.estimator.ModeKeys.EVAL,
+      elif mode == tf_estimator.ModeKeys.EVAL:
+        return tf_estimator.tpu.TPUEstimatorSpec(
+            tf_estimator.ModeKeys.EVAL,
            evaluation_hooks=[restore_hook],
            loss=tf_loss,
            eval_metrics=eval_metrics)
@@ -439,15 +440,15 @@ def main(_):
   tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
       FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

-  run_config = tf.estimator.tpu.RunConfig(
+  run_config = tf_estimator.tpu.RunConfig(
       cluster=tpu_cluster_resolver,
       master=FLAGS.master,
       model_dir=FLAGS.output_dir,
       save_checkpoints_steps=FLAGS.save_checkpoints_steps,
-      tpu_config=tf.estimator.tpu.TPUConfig(
+      tpu_config=tf_estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_cores_per_replica=1,
-          per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig
+          per_host_input_for_training=tf_estimator.tpu.InputPipelineConfig
          .BROADCAST))

   model_fn = model_fn_builder(
@@ -459,7 +460,7 @@ def main(_):

   # If TPU is not available, this will fall back to normal Estimator on CPU
   # or GPU.
-  estimator = tf.estimator.tpu.TPUEstimator(
+  estimator = tf_estimator.tpu.TPUEstimator(
       use_tpu=FLAGS.use_tpu,
       model_fn=model_fn,
       config=run_config,
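
As a quick sanity check that the alias exposes every Estimator symbol the rewritten files rely on, the module can be probed directly. This is a hypothetical verification snippet, not part of the commit:

from tensorflow.compat.v1 import estimator as tf_estimator

# Top-level symbols referenced across the diff.
for name in ("ModeKeys", "Estimator", "EstimatorSpec", "export", "tpu"):
  assert hasattr(tf_estimator, name), "missing tf_estimator.%s" % name

# TPU symbols used by the BERT scripts.
for name in ("TPUEstimator", "TPUEstimatorSpec", "RunConfig", "TPUConfig",
             "InputPipelineConfig"):
  assert hasattr(tf_estimator.tpu, name), "missing tf_estimator.tpu.%s" % name

print("tf_estimator alias exposes all Estimator symbols used in this commit")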
