Skip to content

Commit 5fc6aa7

Browse files
authored
Merge pull request #367 from AvihayTsayeg/master
Enable configuring TFoS server IP and PORT
2 parents 0295f46 + 3af34da commit 5fc6aa7

File tree

5 files changed

+49
-8
lines changed

5 files changed

+49
-8
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,5 @@ docs/.doctrees
1010
target
1111
test-data
1212
dependency-reduced-pom.xml
13+
venv
14+
.idea

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ matrix:
1313
- export PYTHONPATH=$(pwd)
1414
install:
1515
- pip install -r requirements.txt
16+
- pip install mock
1617
script:
1718
- test/run_tests.sh
1819
- language: python

tensorflowonspark/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from pyspark.sql import Row, SparkSession
2424

2525
import tensorflow as tf
26-
from tensorflow.contrib.saved_model.python.saved_model import reader, signature_def_utils
26+
from tensorflow.contrib.saved_model.python.saved_model import reader
2727
from tensorflow.python.saved_model import loader
2828
from . import TFCluster, gpu_info, dfutil
2929

@@ -503,7 +503,7 @@ def _run_model(iterator, args, tf_args):
503503
assert args.export_dir, "Inferencing with signature_def_key requires --export_dir argument"
504504
logging.info("===== loading meta_graph_def for tag_set ({0}) from saved_model: {1}".format(args.tag_set, args.export_dir))
505505
meta_graph_def = get_meta_graph_def(args.export_dir, args.tag_set)
506-
signature = signature_def_utils.get_signature_def_by_key(meta_graph_def, args.signature_def_key)
506+
signature = meta_graph_def.signature_def[args.signature_def_key]
507507
logging.debug("signature: {}".format(signature))
508508
inputs_tensor_info = signature.inputs
509509
logging.debug("inputs_tensor_info: {0}".format(inputs_tensor_info))

tensorflowonspark/reservation.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from __future__ import print_function
1010

1111
import logging
12+
import os
1213
import pickle
1314
import select
1415
import socket
@@ -19,6 +20,8 @@
1920

2021
from . import util
2122

23+
TFOS_SERVER_PORT = "TFOS_SERVER_PORT"
24+
TFOS_SERVER_HOST = "TFOS_SERVER_HOST"
2225
BUFSIZE = 1024
2326
MAX_RETRIES = 3
2427

@@ -146,13 +149,10 @@ def start(self):
146149
Returns:
147150
address of the Server as a tuple of (host, port)
148151
"""
149-
server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
150-
server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
151-
server_sock.bind(('', 0))
152-
server_sock.listen(10)
152+
server_sock = self.start_listening_socket()
153153

154154
# hostname may not be resolvable but IP address probably will be
155-
host = util.get_ip_address()
155+
host = self.get_server_ip()
156156
port = server_sock.getsockname()[1]
157157
addr = (host, port)
158158
logging.info("listening for reservations at {0}".format(addr))
@@ -185,6 +185,18 @@ def _listen(self, sock):
185185

186186
return addr
187187

188+
def get_server_ip(self):
189+
return os.getenv(TFOS_SERVER_HOST) if os.getenv(TFOS_SERVER_HOST) else util.get_ip_address()
190+
191+
192+
def start_listening_socket(self):
193+
port_number = int(os.getenv(TFOS_SERVER_PORT)) if os.getenv(TFOS_SERVER_PORT) else 0
194+
server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
195+
server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
196+
server_sock.bind(('', port_number))
197+
server_sock.listen(10)
198+
return server_sock
199+
188200
def stop(self):
189201
"""Stop the Server's socket listener."""
190202
self.done = True

test/test_reservation.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1+
import os
2+
import sys
13
import threading
24
import time
35
import unittest
46

7+
from tensorflowonspark import util
58
from tensorflowonspark.reservation import Reservations, Server, Client
69

10+
if sys.version_info >= (3, 3):
11+
from unittest import mock
12+
else:
13+
import mock
714

815
class ReservationTest(unittest.TestCase):
916
def test_reservation_class(self):
@@ -48,6 +55,25 @@ def test_reservation_server(self):
4855
time.sleep(1)
4956
self.assertEqual(s.done, True)
5057

58+
def test_reservation_enviroment_exists_get_server_ip_return_environment_value(self):
59+
tfso_server = Server(5)
60+
with mock.patch.dict(os.environ,{'TFOS_SERVER_HOST':'my_host_ip'}):
61+
assert tfso_server.get_server_ip() == "my_host_ip"
62+
63+
def test_reservation_enviroment_not_exists_get_server_ip_return_actual_host_ip(self):
64+
tfso_server = Server(5)
65+
assert tfso_server.get_server_ip() == util.get_ip_address()
66+
67+
def test_reservation_enviroment_exists_start_listening_socket_return_socket_listening_to_environment_port_value(self):
68+
tfso_server = Server(1)
69+
with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9999'}):
70+
assert tfso_server.start_listening_socket().getsockname()[1] == 9999
71+
72+
def test_reservation_enviroment_not_exists_start_listening_socket_return_socket(self):
73+
tfso_server = Server(1)
74+
print(tfso_server.start_listening_socket().getsockname()[1])
75+
assert type(tfso_server.start_listening_socket().getsockname()[1]) == int
76+
5177
def test_reservation_server_multi(self):
5278
"""Test reservation server, expecting multiple reservations"""
5379
num_clients = 4
@@ -85,4 +111,4 @@ def reserve(num):
85111

86112

87113
if __name__ == '__main__':
88-
unittest.main()
114+
unittest.main()

0 commit comments

Comments
 (0)