
Commit 61f7bfb

Merge pull request #327 from yahoo/leewyang_estimator_inference
Distributed inferencing via estimator.predict
2 parents: 7d922f3 + b5db074

File tree

2 files changed: +91 additions, -53 deletions

examples/mnist/keras/mnist_mlp_estimator.py

Lines changed: 80 additions & 45 deletions
@@ -39,46 +39,43 @@ def main_fun(args, ctx):
   estimator = tf.keras.estimator.model_to_estimator(model, model_dir=args.model_dir)
 
   # setup train_input_fn for InputMode.TENSORFLOW or InputMode.SPARK
-  if args.input_mode == 'tf':
-    train_input_fn = tf.estimator.inputs.numpy_input_fn(
-        x={"dense_1_input": x_train},
-        y=y_train,
-        batch_size=128,
-        num_epochs=None,
-        shuffle=True)
-  else: # 'spark'
-    tf_feed = TFNode.DataFeed(ctx.mgr)
-
-    def rdd_generator():
-      while not tf_feed.should_stop():
-        batch = tf_feed.next_batch(1)
-        if len(batch) > 0:
-          record = batch[0]
-          image = numpy.array(record[0]).astype(numpy.float32) / 255.0
-          label = numpy.array(record[1]).astype(numpy.float32)
-          yield (image, label)
-
-    def train_input_fn():
-      ds = tf.data.Dataset.from_generator(rdd_generator,
-                                          (tf.float32, tf.float32),
-                                          (tf.TensorShape([IMAGE_PIXELS * IMAGE_PIXELS]), tf.TensorShape([10])))
-      ds = ds.batch(args.batch_size)
-      return ds
-
-  # eval_input_fn ALWAYS uses data loaded in memory, since InputMode.SPARK can only feed one RDD at a time
-  eval_input_fn = tf.estimator.inputs.numpy_input_fn(
-      x={"dense_1_input": x_test},
-      y=y_test,
-      num_epochs=args.epochs,
-      shuffle=False)
-
-  # setup tf.estimator.train_and_evaluate()
-  train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=args.steps)
-  eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
-  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
-
-  # export a saved_model, if export_dir provided
-  if args.export_dir:
+  if args.mode == 'train':
+    if args.input_mode == 'tf':
+      # For InputMode.TENSORFLOW, just use data in memory
+      train_input_fn = tf.estimator.inputs.numpy_input_fn(
+          x={"dense_1_input": x_train},
+          y=y_train,
+          batch_size=128,
+          num_epochs=None,
+          shuffle=True)
+    else: # 'spark'
+      # For InputMode.SPARK, read data from RDD
+      tf_feed = TFNode.DataFeed(ctx.mgr)
+
+      def rdd_generator():
+        while not tf_feed.should_stop():
+          batch = tf_feed.next_batch(1)
+          if len(batch) > 0:
+            record = batch[0]
+            image = numpy.array(record[0]).astype(numpy.float32) / 255.0
+            label = numpy.array(record[1]).astype(numpy.float32)
+            yield (image, label)
+
+      def train_input_fn():
+        ds = tf.data.Dataset.from_generator(rdd_generator,
+                                            (tf.float32, tf.float32),
+                                            (tf.TensorShape([IMAGE_PIXELS * IMAGE_PIXELS]), tf.TensorShape([10])))
+        ds = ds.batch(args.batch_size)
+        return ds
+
+    # eval_input_fn ALWAYS uses data loaded in memory, since InputMode.SPARK can only feed one RDD at a time
+    eval_input_fn = tf.estimator.inputs.numpy_input_fn(
+        x={"dense_1_input": x_test},
+        y=y_test,
+        num_epochs=args.epochs,
+        shuffle=False)
+
+    # serving_input_receiver_fn ALWAYS expects serialized TFExamples in a placeholder.
     def serving_input_receiver_fn():
       """An input receiver that expects a serialized tf.Example."""
       serialized_tf_example = tf.placeholder(dtype=tf.string,
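
The Dataset.from_generator calls above pass the generator followed by its output types and output shapes as positional arguments. A minimal standalone sketch of the same pattern, using a dummy generator in place of the TFNode.DataFeed (all names and values below are illustrative, not part of this commit):

import numpy
import tensorflow as tf

IMAGE_PIXELS = 28

def dummy_generator():
  # stands in for rdd_generator(): yields one (image, label) pair at a time
  for _ in range(10):
    yield (numpy.zeros([IMAGE_PIXELS * IMAGE_PIXELS], dtype=numpy.float32),
           numpy.zeros([10], dtype=numpy.float32))

ds = tf.data.Dataset.from_generator(
    dummy_generator,
    (tf.float32, tf.float32),  # output_types: (image, label)
    (tf.TensorShape([IMAGE_PIXELS * IMAGE_PIXELS]), tf.TensorShape([10])))  # output_shapes
ds = ds.batch(32)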
@@ -89,7 +86,35 @@ def serving_input_receiver_fn():
       features = tf.parse_example(serialized_tf_example, feature_spec)
       return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
 
-    estimator.export_savedmodel(args.export_dir, serving_input_receiver_fn)
+    # setup tf.estimator.train_and_evaluate() w/ FinalExporter
+    exporter = tf.estimator.FinalExporter("serving", serving_input_receiver_fn=serving_input_receiver_fn)
+    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=args.steps)
+    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, exporters=exporter)
+    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
+
+  else: # mode == 'inference'
+    if args.input_mode == 'spark':
+      tf_feed = TFNode.DataFeed(ctx.mgr)
+
+      def rdd_generator():
+        while not tf_feed.should_stop():
+          batch = tf_feed.next_batch(1)
+          if len(batch) > 0:
+            record = batch[0]
+            image = numpy.array(record[0]).astype(numpy.float32) / 255.0
+            label = numpy.array(record[1]).astype(numpy.float32)
+            yield (image, label)
+
+      def predict_input_fn():
+        ds = tf.data.Dataset.from_generator(rdd_generator,
+                                            (tf.float32, tf.float32),
+                                            (tf.TensorShape([IMAGE_PIXELS * IMAGE_PIXELS]), tf.TensorShape([10])))
+        ds = ds.batch(args.batch_size)
+        return ds
+
+      predictions = estimator.predict(predict_input_fn)
+      for result in predictions:
+        tf_feed.batch_results([result])
 
 
 if __name__ == '__main__':
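
Because the FinalExporter is wired to serving_input_receiver_fn, the saved_model exported after training expects serialized tf.Example protos. A hedged sketch of building one such request (the 'dense_1_input' feature name mirrors the input_fn keys above; everything else is illustrative):

import numpy
import tensorflow as tf

image = numpy.zeros([28 * 28], dtype=numpy.float32)  # a flattened MNIST image
example = tf.train.Example(features=tf.train.Features(feature={
    'dense_1_input': tf.train.Feature(float_list=tf.train.FloatList(value=image.tolist()))
}))
serialized = example.SerializeToString()  # what the exported serving signature consumes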
@@ -112,6 +137,8 @@ def serving_input_receiver_fn():
   parser.add_argument("--input_mode", help="input mode (tf|spark)", default="tf")
   parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized CSV format")
   parser.add_argument("--model_dir", help="directory to write model checkpoints")
+  parser.add_argument("--mode", help="(train|inference)")
+  parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
   parser.add_argument("--num_ps", help="number of ps nodes", type=int, default=1)
   parser.add_argument("--steps", help="max number of steps to train", type=int, default=2000)
   parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
@@ -120,14 +147,22 @@ def serving_input_receiver_fn():
   print("args:", args)
 
   if args.input_mode == 'tf':
-    # for TENSORFLOW mode, each node will load/train entire dataset in memory per original example
+    # for TENSORFLOW mode, each node will load/train/infer entire dataset in memory per original example
     cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard, TFCluster.InputMode.TENSORFLOW, log_dir=args.model_dir, master_node='master')
     cluster.shutdown()
   else: # 'spark'
     # for SPARK mode, just use CSV format as an example
     images = sc.textFile(args.images).map(lambda ln: [float(x) for x in ln.split(',')])
     labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
     dataRDD = images.zip(labels)
-    cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard, TFCluster.InputMode.SPARK, log_dir=args.model_dir, master_node='master')
-    cluster.train(dataRDD, args.epochs)
-    cluster.shutdown()
+    if args.mode == 'train':
+      cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard, TFCluster.InputMode.SPARK, log_dir=args.model_dir, master_node='master')
+      cluster.train(dataRDD, args.epochs)
+      cluster.shutdown()
+    else:
+      # Note: using "parallel" inferencing, not "cluster"
+      # each node loads the model and runs independently of others
+      cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, 0, args.tensorboard, TFCluster.InputMode.SPARK, log_dir=args.model_dir)
+      resultRDD = cluster.inference(dataRDD)
+      resultRDD.saveAsTextFile(args.output)
+      cluster.shutdown()
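
In the inference branch above, TFCluster.run is invoked with zero parameter servers and no master_node, so every executor restores the trained model from model_dir and calls estimator.predict on its own partition of dataRDD; cluster.inference then surfaces the per-example results as an RDD. A hedged sketch of inspecting that output on the driver before saving (illustrative only):

resultRDD = cluster.inference(dataRDD)
for pred in resultRDD.take(3):  # peek at a few prediction records
  print(pred)
resultRDD.saveAsTextFile(args.output)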

tensorflowonspark/TFSparkNode.py

Lines changed: 11 additions & 8 deletions
@@ -271,14 +271,17 @@ def _mapfn(iter):
         hosts.append("{0}:{1}".format(nhost, nport))
       spec[njob] = hosts
 
-    # update TF_CONFIG and reserve GPU for tf.estimator based code
-    # Note: this will execute but be ignored by non-tf.estimator code
-    tf_config = json.dumps({
-      'cluster': spec,
-      'task': {'type': job_name, 'index': task_index},
-      'environment': 'cloud'
-    })
-    os.environ['TF_CONFIG'] = tf_config
+    # update TF_CONFIG if cluster spec has a 'master' node (i.e. tf.estimator)
+    if 'master' in spec:
+      tf_config = json.dumps({
+        'cluster': spec,
+        'task': {'type': job_name, 'index': task_index},
+        'environment': 'cloud'
+      })
+      logging.info("export TF_CONFIG: {}".format(tf_config))
+      os.environ['TF_CONFIG'] = tf_config
+
+    # reserve GPU
     if tf.test.is_built_with_cuda():
       num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1
       gpus_to_use = gpu_info.get_gpus(num_gpus)
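
For reference, a hedged illustration of the TF_CONFIG this exports when the spec contains a 'master' node; all host:port values and the task assignment below are hypothetical:

# roughly what json.loads(os.environ['TF_CONFIG']) would yield on one worker
# of a 4-node estimator cluster (hosts/ports are made up):
{
    'cluster': {
        'master': ['node1:36000'],
        'ps': ['node2:36001'],
        'worker': ['node3:36002', 'node4:36003']
    },
    'task': {'type': 'worker', 'index': 0},
    'environment': 'cloud'
}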
