@@ -137,6 +137,7 @@ def run(fn, tf_args, cluster_meta, tensorboard, log_dir, queues, background):
137137 """
138138 def _mapfn (iter ):
139139 import tensorflow as tf
140+ from packaging import version
140141
141142 # Note: consuming the input iterator helps Pyspark re-use this worker,
142143 for i in iter :
@@ -198,10 +199,12 @@ def _mapfn(iter):
198199 logger .debug ("CLASSPATH: {0}" .format (hadoop_classpath ))
199200 os .environ ['CLASSPATH' ] = classpath + os .pathsep + hadoop_classpath
200201
201- # start TensorBoard if requested
202+ # start TensorBoard if requested, on 'worker:0' if available (for backwards-compatibility), otherwise on 'chief:0' or 'master:0'
203+ job_names = sorted ([k for k in cluster_template .keys () if k in ['chief' , 'master' , 'worker' ]])
204+ tb_job_name = 'worker' if 'worker' in job_names else job_names [0 ]
202205 tb_pid = 0
203206 tb_port = 0
204- if tensorboard and job_name == 'worker' and task_index == 0 :
207+ if tensorboard and job_name == tb_job_name and task_index == 0 :
205208 tb_sock = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
206209 tb_sock .bind (('' , 0 ))
207210 tb_port = tb_sock .getsockname ()[1 ]
@@ -223,7 +226,11 @@ def _mapfn(iter):
223226 raise Exception ("Unable to find 'tensorboard' in: {}" .format (search_path ))
224227
225228 # launch tensorboard
226- tb_proc = subprocess .Popen ([pypath , tb_path , "--logdir=%s" % logdir , "--port=%d" % tb_port ], env = os .environ )
229+ if version .parse (tf .__version__ ) >= version .parse ('2.0.0' ):
230+ tb_proc = subprocess .Popen ([pypath , tb_path , "--reload_multifile=True" , "--logdir=%s" % logdir , "--port=%d" % tb_port ], env = os .environ )
231+ else :
232+ tb_proc = subprocess .Popen ([pypath , tb_path , "--logdir=%s" % logdir , "--port=%d" % tb_port ], env = os .environ )
233+
227234 tb_pid = tb_proc .pid
228235
229236 # check server to see if this task is being retried (i.e. already reserved)
0 commit comments