Skip to content

Commit 2417f9c

Browse files
author
Lee Yang
committed
add --reload_multifile option to tb launch; use version.parse() in compat lib
1 parent 372a917 commit 2417f9c

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

tensorflowonspark/TFSparkNode.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def run(fn, tf_args, cluster_meta, tensorboard, log_dir, queues, background):
137137
"""
138138
def _mapfn(iter):
139139
import tensorflow as tf
140+
from packaging import version
140141

141142
# Note: consuming the input iterator helps Pyspark re-use this worker,
142143
for i in iter:
@@ -225,7 +226,11 @@ def _mapfn(iter):
225226
raise Exception("Unable to find 'tensorboard' in: {}".format(search_path))
226227

227228
# launch tensorboard
228-
tb_proc = subprocess.Popen([pypath, tb_path, "--logdir=%s" % logdir, "--port=%d" % tb_port], env=os.environ)
229+
if version.parse(tf.__version__) >= version.parse('2.0.0'):
230+
tb_proc = subprocess.Popen([pypath, tb_path, "--reload_multifile=True", "--logdir=%s" % logdir, "--port=%d" % tb_port], env=os.environ)
231+
else:
232+
tb_proc = subprocess.Popen([pypath, tb_path, "--logdir=%s" % logdir, "--port=%d" % tb_port], env=os.environ)
233+
229234
tb_pid = tb_proc.pid
230235

231236
# check server to see if this task is being retried (i.e. already reserved)

tensorflowonspark/compat.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
# Copyright 2019 Yahoo Inc / Verizon Media
22
# Licensed under the terms of the Apache 2.0 license.
33
# Please see LICENSE file in the project root for terms.
4-
"""Helper functions to abstract API changes between TensorFlow versions."""
4+
"""Helper functions to abstract API changes between TensorFlow versions, intended for end-user TF code."""
55

66
import tensorflow as tf
7-
8-
TF_VERSION = tf.__version__
7+
from packaging import version
98

109

1110
def export_saved_model(model, export_dir, is_chief=False):
12-
if TF_VERSION == '2.0.0':
11+
if version.parse(tf.__version__) == version.parse('2.0.0'):
1312
if is_chief:
1413
tf.keras.experimental.export_saved_model(model, export_dir)
1514
else:
1615
model.save(export_dir, save_format='tf')
1716

1817

1918
def disable_auto_shard(options):
20-
if TF_VERSION == '2.0.0':
19+
if version.parse(tf.__version__) == version.parse('2.0.0'):
2120
options.experimental_distribute.auto_shard = False
2221
else:
2322
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF

0 commit comments

Comments
 (0)