Commit ce5e789

Merge pull request #183 from winston-zillow/master
Allow running PS nodes on the spark driver
2 parents: 40865ae + 8e690a1

File tree: 6 files changed (+49, -7 lines)


examples/mnist/tf/mnist_spark.py

Lines changed: 3 additions & 1 deletion

@@ -33,12 +33,14 @@
 parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
 parser.add_argument("-X", "--mode", help="train|inference", default="train")
 parser.add_argument("-c", "--rdma", help="use rdma connection", default=False)
+parser.add_argument("-p", "--driver_ps_nodes", help="run tensorflow PS node on driver locally", default=False)
 args = parser.parse_args()
 print("args:",args)


 print("{0} ===== Start".format(datetime.now().isoformat()))
-cluster = TFCluster.run(sc, mnist_dist.map_fun, args, args.cluster_size, num_ps, args.tensorboard, TFCluster.InputMode.TENSORFLOW, log_dir=args.model)
+cluster = TFCluster.run(sc, mnist_dist.map_fun, args, args.cluster_size, num_ps, args.tensorboard, TFCluster.InputMode.TENSORFLOW,
+                        driver_ps_nodes=args.driver_ps_nodes, log_dir=args.model)
 cluster.shutdown()

 print("{0} ===== Stop".format(datetime.now().isoformat()))
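
Note that the new flag is declared with default=False but no type or action, so any value given on the command line arrives as a string, and even "-p False" would be truthy. A hedged sketch of a stricter variant (str2bool is a hypothetical helper, not part of this commit):

def str2bool(v):
  # argparse passes the raw string through unless a type converter is given,
  # so convert common spellings of "true" explicitly
  return str(v).lower() in ("yes", "true", "t", "1")

parser.add_argument("-p", "--driver_ps_nodes", type=str2bool, default=False,
                    help="run tensorflow PS node on driver locally")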

examples/mnist/tf/mnist_spark_dataset.py

Lines changed: 4 additions & 1 deletion

@@ -34,12 +34,15 @@
 parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
 parser.add_argument("-X", "--mode", help="train|inference", default="train")
 parser.add_argument("-c", "--rdma", help="use rdma connection", default=False)
+parser.add_argument("-p", "--driver_ps_nodes", help="""run tensorflow PS node on driver locally.
+                    You will need to set cluster_size = num_executors + num_ps""", default=False)
 args = parser.parse_args()
 print("args:",args)


 print("{0} ===== Start".format(datetime.now().isoformat()))
-cluster = TFCluster.run(sc, mnist_dist_dataset.map_fun, args, args.cluster_size, num_ps, args.tensorboard, TFCluster.InputMode.TENSORFLOW)
+cluster = TFCluster.run(sc, mnist_dist_dataset.map_fun, args, args.cluster_size, num_ps, args.tensorboard,
+                        TFCluster.InputMode.TENSORFLOW, driver_ps_nodes=args.driver_ps_nodes)
 cluster.shutdown()

 print("{0} ===== Stop".format(datetime.now().isoformat()))
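
The sizing rule in the new help text is worth spelling out: with driver_ps_nodes enabled, the PS tasks live on the driver, so cluster_size counts both the executor-hosted workers and the driver-hosted PS nodes. A minimal sketch with illustrative numbers (not from this commit):

# Illustrative sizing, assuming 4 Spark executors are available for workers
# and 1 PS task should run on the driver:
num_workers = 4                       # one worker per Spark executor
num_ps = 1                            # PS task(s) hosted on the driver
cluster_size = num_workers + num_ps   # 5 logical TF nodes in total

cluster = TFCluster.run(sc, mnist_dist_dataset.map_fun, args, cluster_size, num_ps,
                        args.tensorboard, TFCluster.InputMode.TENSORFLOW,
                        driver_ps_nodes=True)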

examples/mnist/tf/mnist_spark_pipeline.py

Lines changed: 3 additions & 0 deletions

@@ -38,6 +38,8 @@
 parser.add_argument("--tfrecord_dir", help="HDFS path to temporarily save DataFrame to disk", type=str)
 parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
 parser.add_argument("--num_ps", help="number of PS nodes in cluster", type=int, default=1)
+parser.add_argument("-p", "--driver_ps_nodes", help="""run tensorflow PS node on driver locally.
+                    You will need to set cluster_size = num_executors + num_ps""", default=False)
 parser.add_argument("--protocol", help="Tensorflow network protocol (grpc|rdma)", default="grpc")
 parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
 parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
@@ -81,6 +83,7 @@
         .setExportDir(args.export_dir) \
         .setClusterSize(args.cluster_size) \
         .setNumPS(args.num_ps) \
+        .setDriverPSNodes(args.driver_ps_nodes) \
         .setInputMode(TFCluster.InputMode.TENSORFLOW) \
         .setTFRecordDir(args.tfrecord_dir) \
         .setProtocol(args.protocol) \

tensorflowonspark/TFCluster.py

Lines changed: 27 additions & 3 deletions

@@ -186,7 +186,8 @@ def tensorboard_url(self):
     tb_url = "http://{0}:{1}".format(node['host'], node['tb_port'])
     return tb_url

-def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mode=InputMode.TENSORFLOW, log_dir=None, queues=['input', 'output']):
+def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mode=InputMode.TENSORFLOW,
+        log_dir=None, driver_ps_nodes=False, queues=['input', 'output']):
   """Starts the TensorFlowOnSpark cluster and Runs the TensorFlow "main" function on the Spark executors

   Args:
@@ -198,6 +199,7 @@
     :tensorboard: boolean indicating if the chief worker should spawn a Tensorboard server.
     :input_mode: TFCluster.InputMode
     :log_dir: directory to save tensorboard event logs. If None, defaults to a fixed path on local filesystem.
+    :driver_ps_nodes: run the PS nodes on the driver locally instead of on the spark executors; this helps maximize computing resources (esp. GPU). You will need to set cluster_size = num_executors + num_ps
     :queues: *INTERNAL_USE*

   Returns:
@@ -206,10 +208,14 @@
   logging.info("Reserving TFSparkNodes {0}".format("w/ TensorBoard" if tensorboard else ""))
   assert num_ps < num_executors

+  if driver_ps_nodes and input_mode != InputMode.TENSORFLOW:
+    raise Exception('running PS nodes on driver locally is only supported in InputMode.TENSORFLOW')
+
   # build a cluster_spec template using worker_nums
   cluster_template = {}
   cluster_template['ps'] = range(num_ps)
   cluster_template['worker'] = range(num_ps, num_executors)
+  logging.info("worker node range %s, ps node range %s" % (cluster_template['worker'], cluster_template['ps']))

   # get default filesystem from spark
   defaultFS = sc._jsc.hadoopConfiguration().get("fs.defaultFS")
@@ -234,7 +240,25 @@
     'working_dir': working_dir,
     'server_addr': server_addr
   }
-  nodeRDD = sc.parallelize(range(num_executors), num_executors)
+  if driver_ps_nodes:
+    nodeRDD = sc.parallelize(range(num_ps, num_executors), num_executors - num_ps)
+  else:
+    nodeRDD = sc.parallelize(range(num_executors), num_executors)
+
+  if driver_ps_nodes:
+    def _start_ps(node_index):
+      logging.info("starting ps node locally %d" % node_index)
+      TFSparkNode.run(map_fun,
+                      tf_args,
+                      cluster_meta,
+                      tensorboard,
+                      log_dir,
+                      queues,
+                      background=(input_mode == InputMode.SPARK))([node_index])
+    for i in cluster_template['ps']:
+      ps_thread = threading.Thread(target=lambda: _start_ps(i))
+      ps_thread.daemon = True
+      ps_thread.start()

   # start TF on a background thread (on Spark driver) to allow for feeding job
   def _start():
@@ -244,7 +268,7 @@ def _start():
                       tensorboard,
                       log_dir,
                       queues,
-                      (input_mode == InputMode.SPARK)))
+                      background=(input_mode == InputMode.SPARK)))
   t = threading.Thread(target=_start)
   t.start()
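
One subtlety in the driver-side loop above: "lambda: _start_ps(i)" captures the loop variable i by reference, and Python closures bind late, so with num_ps > 1 every thread may end up calling _start_ps with the final index. A hedged sketch of the same launch pattern with the index bound eagerly (start_driver_ps is a hypothetical helper, not part of this commit):

import threading

def start_driver_ps(ps_indices, start_fn):
  """Launch one daemon thread per PS index, binding the index at creation time."""
  for i in ps_indices:
    # args=(i,) evaluates i now, avoiding Python's late-binding closure
    # behavior inside the loop
    t = threading.Thread(target=start_fn, args=(i,))
    t.daemon = True   # don't block driver shutdown on PS threads
    t.start()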
tensorflowonspark/TFSparkNode.py

Lines changed: 4 additions & 1 deletion

@@ -270,8 +270,11 @@ def wrapper_fn(args, context):

     if job_name == 'ps' or background:
       # invoke the TensorFlow main function in a background thread
-      logging.info("Starting TensorFlow {0}:{1} on cluster node {2} on background process".format(job_name, task_index, worker_num))
+      logging.info("Starting TensorFlow {0}:{1} as {2} on cluster node {3} on background process".format(
+        job_name, task_index, job_name, worker_num))
       p = multiprocessing.Process(target=wrapper_fn, args=(tf_args, ctx))
+      if job_name == 'ps':
+        p.daemon = True
       p.start()

     # for ps nodes only, wait indefinitely in foreground thread for a "control" event (None == "stop")
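
The new daemon flag is what lets a node hosting a PS task exit cleanly: a daemonic multiprocessing.Process is terminated automatically when its parent exits, instead of keeping the Python worker alive on a server loop that never returns. A minimal standalone illustration of that semantic (independent of this codebase):

import multiprocessing
import time

def serve_forever():
  # stand-in for a TF parameter server, which blocks indefinitely
  while True:
    time.sleep(1)

if __name__ == "__main__":
  p = multiprocessing.Process(target=serve_forever)
  p.daemon = True   # child is killed when the parent process exits
  p.start()
  time.sleep(2)     # parent does its work...
  # ...and exits here; the daemonic child is terminated rather than orphaned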

tensorflowonspark/pipeline.py

Lines changed: 8 additions & 1 deletion

@@ -101,12 +101,17 @@ def getModelDir(self):

 class HasNumPS(Params):
   num_ps = Param(Params._dummy(), "num_ps", "Number of PS nodes in cluster", typeConverter=TypeConverters.toInt)
+  driver_ps_nodes = Param(Params._dummy(), "driver_ps_nodes", "Run PS nodes on driver locally", typeConverter=TypeConverters.toBoolean)
   def __init__(self):
     super(HasNumPS, self).__init__()
   def setNumPS(self, value):
     return self._set(num_ps=value)
   def getNumPS(self):
     return self.getOrDefault(self.num_ps)
+  def setDriverPSNodes(self, value):
+    return self._set(driver_ps_nodes=value)
+  def getDriverPSNodes(self):
+    return self.getOrDefault(self.driver_ps_nodes)

 class HasOutputMapping(Params):
   output_mapping = Param(Params._dummy(), "output_mapping", "Mapping of output tensor to output DataFrame column", typeConverter=TFTypeConverters.toDict)
@@ -276,6 +281,7 @@ def __init__(self, train_fn, tf_args, export_fn=None):
     self._setDefault(input_mapping={},
                      cluster_size=1,
                      num_ps=0,
+                     driver_ps_nodes=False,
                      input_mode=TFCluster.InputMode.SPARK,
                      protocol='grpc',
                      tensorboard=False,
@@ -319,7 +325,8 @@ def _fit(self, dataset):
     logging.info("Done saving")

     tf_args = self.args.argv if self.args.argv else local_args
-    cluster = TFCluster.run(sc, self.train_fn, tf_args, local_args.cluster_size, local_args.num_ps, local_args.tensorboard, local_args.input_mode)
+    cluster = TFCluster.run(sc, self.train_fn, tf_args, local_args.cluster_size, local_args.num_ps,
+                            local_args.tensorboard, local_args.input_mode, driver_ps_nodes=local_args.driver_ps_nodes)
     if local_args.input_mode == TFCluster.InputMode.SPARK:
       # feed data, using a deterministic order for input columns (lexicographic by key)
       input_cols = sorted(self.getInputMapping().keys())
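
For completeness, a hedged sketch of how the new setter composes with the existing builder-style API, assuming the estimator class here is TFEstimator and the map function is mnist_dist_pipeline.map_fun (neither name is visible in this hunk), and recalling that driver-side PS is only valid with InputMode.TENSORFLOW per the TFCluster.py check above:

from tensorflowonspark import TFCluster
from tensorflowonspark.pipeline import TFEstimator

# args: the parsed argparse namespace from mnist_spark_pipeline.py above
estimator = (TFEstimator(mnist_dist_pipeline.map_fun, args)
             .setClusterSize(args.cluster_size)  # must equal num_executors + num_ps
             .setNumPS(args.num_ps)
             .setDriverPSNodes(True)             # host PS task(s) on the driver
             .setInputMode(TFCluster.InputMode.TENSORFLOW))  # required for driver-side PS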
