Skip to content

Commit 147882b

Browse files
authored
Merge pull request #500 from yahoo/leewyang_port_config
configurable ports for TF and TensorBoard
2 parents 2f673d6 + eeb717f commit 147882b

File tree

2 files changed

+30
-20
lines changed

2 files changed

+30
-20
lines changed

tensorflowonspark/TFSparkNode.py

Lines changed: 18 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -209,10 +209,15 @@ def _mapfn(iter):
209209
tb_pid = 0
210210
tb_port = 0
211211
if tensorboard and job_name == tb_job_name and task_index == 0:
212-
tb_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
213-
tb_sock.bind(('', 0))
214-
tb_port = tb_sock.getsockname()[1]
215-
tb_sock.close()
212+
if 'TENSORBOARD_PORT' in os.environ:
213+
# use port defined in env var
214+
tb_port = int(os.environ['TENSORBOARD_PORT'])
215+
else:
216+
# otherwise, find a free port
217+
tb_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
218+
tb_sock.bind(('', 0))
219+
tb_port = tb_sock.getsockname()[1]
220+
tb_sock.close()
216221
logdir = log_dir if log_dir else "tensorboard_%d" % executor_id
217222

218223
# search for tensorboard in python/bin, PATH, and PYTHONPATH
@@ -250,11 +255,15 @@ def _mapfn(iter):
250255

251256
# if not already done, register everything we need to set up the cluster
252257
if node_meta is None:
253-
# first, find a free port for TF
254-
tmp_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
255-
tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
256-
tmp_sock.bind(('', port))
257-
port = tmp_sock.getsockname()[1]
258+
if 'TENSORFLOW_PORT' in os.environ:
259+
# use port defined in env var
260+
port = int(os.environ['TENSORFLOW_PORT'])
261+
else:
262+
# otherwise, find a free port
263+
tmp_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
264+
tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
265+
tmp_sock.bind(('', port))
266+
port = tmp_sock.getsockname()[1]
258267

259268
node_meta = {
260269
'executor_id': executor_id,

test/test_reservation.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
else:
1313
import mock
1414

15+
1516
class ReservationTest(unittest.TestCase):
1617
def test_reservation_class(self):
1718
"""Test core reservation class, expecting 2 reservations"""
@@ -56,23 +57,23 @@ def test_reservation_server(self):
5657
self.assertEqual(s.done, True)
5758

5859
def test_reservation_enviroment_exists_get_server_ip_return_environment_value(self):
59-
tfso_server = Server(5)
60-
with mock.patch.dict(os.environ,{'TFOS_SERVER_HOST':'my_host_ip'}):
61-
assert tfso_server.get_server_ip() == "my_host_ip"
60+
tfos_server = Server(5)
61+
with mock.patch.dict(os.environ, {'TFOS_SERVER_HOST': 'my_host_ip'}):
62+
assert tfos_server.get_server_ip() == "my_host_ip"
6263

6364
def test_reservation_enviroment_not_exists_get_server_ip_return_actual_host_ip(self):
64-
tfso_server = Server(5)
65-
assert tfso_server.get_server_ip() == util.get_ip_address()
65+
tfos_server = Server(5)
66+
assert tfos_server.get_server_ip() == util.get_ip_address()
6667

6768
def test_reservation_enviroment_exists_start_listening_socket_return_socket_listening_to_environment_port_value(self):
68-
tfso_server = Server(1)
69+
tfos_server = Server(1)
6970
with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9999'}):
70-
assert tfso_server.start_listening_socket().getsockname()[1] == 9999
71+
assert tfos_server.start_listening_socket().getsockname()[1] == 9999
7172

7273
def test_reservation_enviroment_not_exists_start_listening_socket_return_socket(self):
73-
tfso_server = Server(1)
74-
print(tfso_server.start_listening_socket().getsockname()[1])
75-
assert type(tfso_server.start_listening_socket().getsockname()[1]) == int
74+
tfos_server = Server(1)
75+
print(tfos_server.start_listening_socket().getsockname()[1])
76+
assert type(tfos_server.start_listening_socket().getsockname()[1]) == int
7677

7778
def test_reservation_server_multi(self):
7879
"""Test reservation server, expecting multiple reservations"""
@@ -111,4 +112,4 @@ def reserve(num):
111112

112113

113114
if __name__ == '__main__':
114-
unittest.main()
115+
unittest.main()

0 commit comments

Comments
 (0)