Skip to content

Commit 26ff94d

Browse files
authored
Merge pull request #487 from yahoo/leewyang_tb_chief
Allow running TensorBoard on chief/master if no workers are available
2 parents 38b13b9 + 2417f9c commit 26ff94d

File tree

3 files changed

+19
-10
lines changed

3 files changed

+19
-10
lines changed

examples/mnist/keras/mnist_tf.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@ def build_and_compile_cnn_model():
5656
# callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=args.model_dir)]
5757
tf.io.gfile.makedirs(args.model_dir)
5858
filepath = args.model_dir + "/weights-{epoch:04d}"
59-
callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=filepath, verbose=1, save_weights_only=True)]
59+
callbacks = [
60+
tf.keras.callbacks.ModelCheckpoint(filepath=filepath, verbose=1, save_weights_only=True),
61+
tf.keras.callbacks.TensorBoard(log_dir=args.model_dir)
62+
]
6063

6164
with strategy.scope():
6265
multi_worker_model = build_and_compile_cnn_model()
@@ -90,5 +93,5 @@ def build_and_compile_cnn_model():
9093
args = parser.parse_args()
9194
print("args:", args)
9295

93-
cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=0, tensorboard=args.tensorboard, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief')
96+
cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=0, tensorboard=args.tensorboard, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief', log_dir=args.model_dir)
9497
cluster.shutdown()

tensorflowonspark/TFSparkNode.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def run(fn, tf_args, cluster_meta, tensorboard, log_dir, queues, background):
137137
"""
138138
def _mapfn(iter):
139139
import tensorflow as tf
140+
from packaging import version
140141

141142
# Note: consuming the input iterator helps Pyspark re-use this worker,
142143
for i in iter:
@@ -198,10 +199,12 @@ def _mapfn(iter):
198199
logger.debug("CLASSPATH: {0}".format(hadoop_classpath))
199200
os.environ['CLASSPATH'] = classpath + os.pathsep + hadoop_classpath
200201

201-
# start TensorBoard if requested
202+
# start TensorBoard if requested, on 'worker:0' if available (for backwards-compatibility), otherwise on 'chief:0' or 'master:0'
203+
job_names = sorted([k for k in cluster_template.keys() if k in ['chief', 'master', 'worker']])
204+
tb_job_name = 'worker' if 'worker' in job_names else job_names[0]
202205
tb_pid = 0
203206
tb_port = 0
204-
if tensorboard and job_name == 'worker' and task_index == 0:
207+
if tensorboard and job_name == tb_job_name and task_index == 0:
205208
tb_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
206209
tb_sock.bind(('', 0))
207210
tb_port = tb_sock.getsockname()[1]
@@ -223,7 +226,11 @@ def _mapfn(iter):
223226
raise Exception("Unable to find 'tensorboard' in: {}".format(search_path))
224227

225228
# launch tensorboard
226-
tb_proc = subprocess.Popen([pypath, tb_path, "--logdir=%s" % logdir, "--port=%d" % tb_port], env=os.environ)
229+
if version.parse(tf.__version__) >= version.parse('2.0.0'):
230+
tb_proc = subprocess.Popen([pypath, tb_path, "--reload_multifile=True", "--logdir=%s" % logdir, "--port=%d" % tb_port], env=os.environ)
231+
else:
232+
tb_proc = subprocess.Popen([pypath, tb_path, "--logdir=%s" % logdir, "--port=%d" % tb_port], env=os.environ)
233+
227234
tb_pid = tb_proc.pid
228235

229236
# check server to see if this task is being retried (i.e. already reserved)

tensorflowonspark/compat.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
# Copyright 2019 Yahoo Inc / Verizon Media
22
# Licensed under the terms of the Apache 2.0 license.
33
# Please see LICENSE file in the project root for terms.
4-
"""Helper functions to abstract API changes between TensorFlow versions."""
4+
"""Helper functions to abstract API changes between TensorFlow versions, intended for end-user TF code."""
55

66
import tensorflow as tf
7-
8-
TF_VERSION = tf.__version__
7+
from packaging import version
98

109

1110
def export_saved_model(model, export_dir, is_chief=False):
12-
if TF_VERSION == '2.0.0':
11+
if version.parse(tf.__version__) == version.parse('2.0.0'):
1312
if is_chief:
1413
tf.keras.experimental.export_saved_model(model, export_dir)
1514
else:
1615
model.save(export_dir, save_format='tf')
1716

1817

1918
def disable_auto_shard(options):
20-
if TF_VERSION == '2.0.0':
19+
if version.parse(tf.__version__) == version.parse('2.0.0'):
2120
options.experimental_distribute.auto_shard = False
2221
else:
2322
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF

0 commit comments

Comments
 (0)