@@ -186,7 +186,8 @@ def tensorboard_url(self):
         tb_url = "http://{0}:{1}".format(node['host'], node['tb_port'])
     return tb_url
 
-def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mode=InputMode.TENSORFLOW, log_dir=None, queues=['input', 'output']):
+def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mode=InputMode.TENSORFLOW,
+        log_dir=None, driver_ps_nodes=False, queues=['input', 'output']):
   """Starts the TensorFlowOnSpark cluster and Runs the TensorFlow "main" function on the Spark executors
 
   Args:
@@ -198,6 +199,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
     :tensorboard: boolean indicating if the chief worker should spawn a Tensorboard server.
     :input_mode: TFCluster.InputMode
     :log_dir: directory to save tensorboard event logs. If None, defaults to a fixed path on local filesystem.
+    :driver_ps_nodes: run the PS nodes locally on the driver instead of on the Spark executors; this helps maximize computing resources (esp. GPUs). You will need to set cluster_size = num_executors + num_ps.
     :queues: *INTERNAL_USE*
 
   Returns:
@@ -206,10 +208,14 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
   logging.info("Reserving TFSparkNodes {0}".format("w/ TensorBoard" if tensorboard else ""))
   assert num_ps < num_executors
 
+  if driver_ps_nodes and input_mode != InputMode.TENSORFLOW:
+    raise Exception('running PS nodes on driver locally is only supported in InputMode.TENSORFLOW')
+
   # build a cluster_spec template using worker_nums
   cluster_template = {}
   cluster_template['ps'] = range(num_ps)
   cluster_template['worker'] = range(num_ps, num_executors)
+  logging.info("worker node range %s, ps node range %s" % (cluster_template['worker'], cluster_template['ps']))
 
   # get default filesystem from spark
   defaultFS = sc._jsc.hadoopConfiguration().get("fs.defaultFS")
@@ -234,7 +240,25 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
     'working_dir': working_dir,
     'server_addr': server_addr
   }
-  nodeRDD = sc.parallelize(range(num_executors), num_executors)
+  if driver_ps_nodes:
+    nodeRDD = sc.parallelize(range(num_ps, num_executors), num_executors - num_ps)
+  else:
+    nodeRDD = sc.parallelize(range(num_executors), num_executors)
+
+  if driver_ps_nodes:
+    def _start_ps(node_index):
+      logging.info("starting ps node locally %d" % node_index)
+      TFSparkNode.run(map_fun,
+                      tf_args,
+                      cluster_meta,
+                      tensorboard,
+                      log_dir,
+                      queues,
+                      background=(input_mode == InputMode.SPARK))([node_index])
+    for i in cluster_template['ps']:
+      ps_thread = threading.Thread(target=_start_ps, args=(i,))
+      ps_thread.daemon = True
+      ps_thread.start()
 
   # start TF on a background thread (on Spark driver) to allow for feeding job
   def _start():
@@ -244,7 +268,7 @@ def _start():
                                              tensorboard,
                                              log_dir,
                                              queues,
-                                             (input_mode == InputMode.SPARK)))
+                                             background=(input_mode == InputMode.SPARK)))
   t = threading.Thread(target=_start)
   t.start()
 
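For context, a minimal driver-program sketch using the new flag follows. This is an
illustrative assumption, not part of the diff: the import path, main_fun, tf_args=None,
and the cluster sizes are placeholders, chosen only to match the run() signature above.

  # Hypothetical driver program exercising driver_ps_nodes; the import path,
  # main_fun, tf_args=None, and the sizes below are illustrative assumptions.
  from pyspark.conf import SparkConf
  from pyspark.context import SparkContext
  from tensorflowonspark import TFCluster

  def main_fun(argv, ctx):
    # Per-node TensorFlow code; ctx identifies this node's role ('ps' or 'worker').
    pass

  sc = SparkContext(conf=SparkConf().setAppName("driver_ps_nodes_example"))
  num_executors = 4  # total TF nodes (PS + workers) in the cluster_spec template
  num_ps = 1

  # With driver_ps_nodes=True, the PS tasks run in daemon threads on the driver,
  # so only the worker tasks (num_executors - num_ps of them) land on executors.
  cluster = TFCluster.run(sc, main_fun, None, num_executors, num_ps,
                          tensorboard=False,
                          input_mode=TFCluster.InputMode.TENSORFLOW,
                          driver_ps_nodes=True)
  cluster.shutdown()

The design rationale, per the docstring above, is that PS tasks are light on
computation, so keeping them off the executors frees executor resources (especially
GPUs) for the workers; note the docstring's requirement to set
cluster_size = num_executors + num_ps.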