Commit 968f9e0: add compatibility w/ TF1.x

Author: Lee Yang
Parent: 9f3cd53

7 files changed: +317 additions, -25 deletions

examples/mnist/estimator/mnist_pipeline.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -45,8 +45,7 @@ def input_fn(mode, input_context=None):
     ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([28, 28, 1]), tf.TensorShape([1])))
     return ds.batch(BATCH_SIZE)
   else:
-    raise Exception("I'm evaluating: mode={}, input_context={}".format(mode, input_context))
-
+    # read evaluation data from tensorflow_datasets directly
     def scale(image, label):
       image = tf.cast(image, tf.float32) / 255.0
       return image, label
```
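The removed `raise` gives way to an evaluation branch that, per the new comment, reads test data from `tensorflow_datasets`. A minimal sketch of what that branch plausibly looks like (the `tfds` import, split name, and `BATCH_SIZE` value are assumptions; only the `scale` helper appears in the diff):

```python
import tensorflow as tf
import tensorflow_datasets as tfds

BATCH_SIZE = 128  # assumed; defined elsewhere in mnist_pipeline.py

def eval_input_fn():
  # Normalize pixel values to [0, 1], as in the scale() helper from the diff.
  def scale(image, label):
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

  # Load the MNIST test split directly from tensorflow_datasets.
  ds = tfds.load(name='mnist', split='test', as_supervised=True)
  return ds.map(scale).batch(BATCH_SIZE)
```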

examples/mnist/keras/README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -130,7 +130,7 @@ For batch inferencing use cases, you can use Spark to run multiple single-node T
     ${TFoS_HOME}/examples/mnist/keras/mnist_inference.py \
     --cluster_size ${SPARK_WORKER_INSTANCES} \
     --images_labels ${TFoS_HOME}/data/mnist/tfr/test \
-    --export_dir ${TFoS_HOME}/mnist_export \
+    --export_dir $SAVED_MODEL \
     --output ${TFoS_HOME}/predictions
 
 #### Train and Inference via Spark ML Pipeline API
```

requirements.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -1,5 +1,6 @@
 h5py>=2.9.0
 numpy>=1.14.0
+packaging
 py4j==0.10.7
 pyspark
 scipy
```
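The new `packaging` dependency supports the version gating introduced in `TFNode.py` below, e.g.:

```python
from packaging import version
import tensorflow as tf

# PEP 440-aware comparison; naive string comparison would mis-order
# versions like "1.10.0" vs "1.9.0".
if version.parse(tf.__version__) >= version.parse("2.0.0"):
  raise Exception("DEPRECATED: Use higher-level APIs like `tf.keras` or `tf.estimator`")
```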

tensorflowonspark/TFNode.py

Lines changed: 131 additions & 4 deletions

```diff
@@ -16,6 +16,8 @@
 
 import getpass
 import logging
+
+from packaging import version
 from six.moves.queue import Empty
 from . import marker
 
@@ -61,8 +63,86 @@ def hdfs_path(ctx, path):
 
 
 def start_cluster_server(ctx, num_gpus=1, rdma=False):
-  """*DEPRECATED*. Use higher-level APIs like `tf.keras` or `tf.estimator`"""
-  raise Exception("DEPRECATED: Use higher-level APIs like `tf.keras` or `tf.estimator`")
+  """Function that wraps the creation of TensorFlow ``tf.train.Server`` for a node in a distributed TensorFlow cluster.
+
+  This is intended to be invoked from within the TF ``map_fun``, replacing explicit code to instantiate ``tf.train.ClusterSpec``
+  and ``tf.train.Server`` objects.
+
+  DEPRECATED for TensorFlow 2.x+.
+
+  Args:
+    :ctx: TFNodeContext containing the metadata specific to this node in the cluster.
+    :num_gpus: number of GPUs desired.
+    :rdma: boolean indicating if RDMA 'iverbs' should be used for cluster communications.
+
+  Returns:
+    A tuple of (cluster_spec, server)
+  """
+  import os
+  import tensorflow as tf
+  import time
+  from . import gpu_info
+
+  if version.parse(tf.__version__) >= version.parse("2.0.0"):
+    raise Exception("DEPRECATED: Use higher-level APIs like `tf.keras` or `tf.estimator`")
+
+  logging.info("{0}: ======== {1}:{2} ========".format(ctx.worker_num, ctx.job_name, ctx.task_index))
+  cluster_spec = ctx.cluster_spec
+  logging.info("{0}: Cluster spec: {1}".format(ctx.worker_num, cluster_spec))
+
+  if tf.test.is_built_with_cuda() and num_gpus > 0:
+    # compute my index relative to other nodes placed on the same host (for GPU allocation)
+    my_addr = cluster_spec[ctx.job_name][ctx.task_index]
+    my_host = my_addr.split(':')[0]
+    flattened = [v for sublist in cluster_spec.values() for v in sublist]
+    local_peers = [p for p in flattened if p.startswith(my_host)]
+    my_index = local_peers.index(my_addr)
+
+    # GPU
+    gpu_initialized = False
+    retries = 3
+    while not gpu_initialized and retries > 0:
+      try:
+        # override PS jobs to not reserve any GPUs
+        if ctx.job_name == 'ps':
+          num_gpus = 0
+
+        # Find free GPU(s) to use
+        gpus_to_use = gpu_info.get_gpus(num_gpus, my_index)
+        gpu_prompt = "GPU" if num_gpus == 1 else "GPUs"
+        logging.info("{0}: Using {1}: {2}".format(ctx.worker_num, gpu_prompt, gpus_to_use))
+
+        # Set GPU device to use for TensorFlow
+        os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use
+
+        # Create a cluster from the parameter server and worker hosts.
+        cluster = tf.train.ClusterSpec(cluster_spec)
+
+        # Create and start a server for the local task.
+        if rdma:
+          server = tf.train.Server(cluster, ctx.job_name, ctx.task_index, protocol="grpc+verbs")
+        else:
+          server = tf.train.Server(cluster, ctx.job_name, ctx.task_index)
+        gpu_initialized = True
+      except Exception as e:
+        print(e)
+        logging.error("{0}: Failed to allocate GPU, trying again...".format(ctx.worker_num))
+        retries -= 1
+        time.sleep(10)
+    if not gpu_initialized:
+      raise Exception("Failed to allocate GPU")
+  else:
+    # CPU
+    os.environ['CUDA_VISIBLE_DEVICES'] = ''
+    logging.info("{0}: Using CPU".format(ctx.worker_num))
+
+    # Create a cluster from the parameter server and worker hosts.
+    cluster = tf.train.ClusterSpec(cluster_spec)
+
+    # Create and start a server for the local task.
+    server = tf.train.Server(cluster, ctx.job_name, ctx.task_index)
+
+  return (cluster, server)
 
 
 def next_batch(mgr, batch_size, qname='input'):
```
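As the restored docstring notes, `start_cluster_server` is meant to be called from inside the TF `map_fun`. A hedged sketch of a typical TF 1.x map function using it (the graph-building and training details are illustrative, not part of this commit):

```python
def map_fun(args, ctx):
  import tensorflow as tf
  from tensorflowonspark import TFNode

  # Build the ClusterSpec and start a grpc Server for this executor's role.
  cluster, server = TFNode.start_cluster_server(ctx, num_gpus=1, rdma=False)

  if ctx.job_name == "ps":
    # Parameter servers block here, serving variables to the workers.
    server.join()
  elif ctx.job_name == "worker":
    # Classic TF 1.x between-graph replication: pin variables to the PS tasks.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:{}".format(ctx.task_index),
        cluster=cluster)):
      pass  # ...build the model graph here...
    # ...then run a training loop in a session against server.target...
```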
The second hunk in `tensorflowonspark/TFNode.py` restores `export_saved_model` the same way:

```diff
@@ -71,8 +151,55 @@ def next_batch(mgr, batch_size, qname='input'):
 
 
 def export_saved_model(sess, export_dir, tag_set, signatures):
-  """*DEPRECATED*. Use TF provided APIs instead."""
-  raise Exception("DEPRECATED: Use TF provided APIs instead.")
+  """Convenience function to export a saved_model using the provided arguments.
+
+  The caller specifies the saved_model signatures in a simplified python dictionary form, as follows::
+
+    signatures = {
+      'signature_def_key': {
+        'inputs': { 'input_tensor_alias': input_tensor_name },
+        'outputs': { 'output_tensor_alias': output_tensor_name },
+        'method_name': 'method'
+      }
+    }
+
+  This function will then generate the ``signature_def_map`` and export the saved_model.
+
+  DEPRECATED for TensorFlow 2.x+.
+
+  Args:
+    :sess: a tf.Session instance
+    :export_dir: path to save the exported saved_model
+    :tag_set: string tag_set to identify the exported graph
+    :signatures: simplified dictionary representation of a TensorFlow signature_def_map
+
+  Returns:
+    A saved_model exported to disk at ``export_dir``.
+  """
+  import tensorflow as tf
+
+  if version.parse(tf.__version__) >= version.parse("2.0.0"):
+    raise Exception("DEPRECATED: Use TF provided APIs instead.")
+
+  g = sess.graph
+  g._unsafe_unfinalize()  # https://github.com/tensorflow/serving/issues/363
+  builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
+
+  logging.info("===== signatures: {}".format(signatures))
+  signature_def_map = {}
+  for key, sig in signatures.items():
+    signature_def_map[key] = tf.saved_model.signature_def_utils.build_signature_def(
+      inputs={name: tf.saved_model.utils.build_tensor_info(tensor) for name, tensor in sig['inputs'].items()},
+      outputs={name: tf.saved_model.utils.build_tensor_info(tensor) for name, tensor in sig['outputs'].items()},
+      method_name=sig['method_name'] if 'method_name' in sig else key)
+  logging.info("===== signature_def_map: {}".format(signature_def_map))
+  builder.add_meta_graph_and_variables(
+    sess,
+    tag_set.split(','),
+    signature_def_map=signature_def_map,
+    clear_devices=True)
+  g.finalize()
+  builder.save()
 
 
 def batch_results(mgr, results, qname='output'):
```
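The simplified signatures dictionary documented above maps onto a call like the following sketch (the placeholder graph and tensor names are illustrative only):

```python
import tensorflow as tf
from tensorflowonspark import TFNode

# Illustrative TF 1.x graph; x and y stand in for a real model's tensors.
x = tf.placeholder(tf.float32, [None, 784], name='x')
y = tf.layers.dense(x, 10, name='y')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  signatures = {
    'serving_default': {
      'inputs': {'image': x},
      'outputs': {'prediction': y},
      'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME
    }
  }
  TFNode.export_saved_model(sess, 'mnist_export', 'serve', signatures)
```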

tensorflowonspark/compat.py

Lines changed: 23 additions & 0 deletions

```diff
@@ -0,0 +1,23 @@
+# Copyright 2019 Yahoo Inc / Verizon Media
+# Licensed under the terms of the Apache 2.0 license.
+# Please see LICENSE file in the project root for terms.
+"""Helper functions to abstract API changes between TensorFlow versions."""
+
+import tensorflow as tf
+
+TF_VERSION = tf.__version__
+
+
+def export_saved_model(model, export_dir, is_chief=False):
+  if TF_VERSION == '2.0.0':
+    if is_chief:
+      tf.keras.experimental.export_saved_model(model, export_dir)
+  else:
+    model.save(export_dir, save_format='tf')
+
+
+def disable_auto_shard(options):
+  if TF_VERSION == '2.0.0':
+    options.experimental_distribute.auto_shard = False
+  else:
+    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
```
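A brief sketch of how a Keras job might use these helpers (the surrounding dataset and model code is illustrative):

```python
import tensorflow as tf
from tensorflowonspark import compat

def train_and_export(model, ds, export_dir, is_chief):
  # Disable tf.data auto-sharding via whichever option name this TF version expects.
  options = tf.data.Options()
  compat.disable_auto_shard(options)
  ds = ds.with_options(options)

  model.fit(ds, epochs=1)

  # Export with the version-appropriate API: in TF 2.0.0 only the chief exports,
  # while on other versions every node calls model.save().
  compat.export_saved_model(model, export_dir, is_chief=is_chief)
```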
