Skip to content

Commit 7179a8d

Browse files
authored
minor patches (#585)
* minor patches Co-authored-by: Lee Yang <[email protected]>
1 parent 299cfec commit 7179a8d

File tree

6 files changed

+19
-8
lines changed

6 files changed

+19
-8
lines changed

tensorflowonspark/TFNode.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
 from . import compat, marker

 logger = logging.getLogger(__name__)
-TF_VERSION = pkg_resources.get_distribution('tensorflow').version
+try:
+  TF_VERSION = pkg_resources.get_distribution('tensorflow').version
+except pkg_resources.DistributionNotFound:
+  TF_VERSION = pkg_resources.get_distribution('tensorflow-cpu').version


 def hdfs_path(ctx, path):

tensorflowonspark/TFParallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _run(it):
       nodes = [t.address for t in tasks]
       num_workers = len(nodes)
     else:
-      nodes = None
+      nodes = []
       num_workers = num_executors

     # use the placement info to help allocate GPUs

tensorflowonspark/TFSparkNode.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@
 from . import util

 logger = logging.getLogger(__name__)
-TF_VERSION = pkg_resources.get_distribution('tensorflow').version
+try:
+  TF_VERSION = pkg_resources.get_distribution('tensorflow').version
+except pkg_resources.DistributionNotFound:
+  TF_VERSION = pkg_resources.get_distribution('tensorflow-cpu').version


 def _has_spark_resource_api():
@@ -502,7 +505,7 @@ def _train(iter):
       joinThr = Thread(target=queue.join)
       joinThr.start()
       timeout = feed_timeout
-      while (joinThr.isAlive()):
+      while (joinThr.is_alive()):
         if (not equeue.empty()):
           e_str = equeue.get()
           raise Exception("Exception in worker:\n" + e_str)
@@ -570,7 +573,7 @@ def _inference(iter):
       joinThr = Thread(target=queue_in.join)
       joinThr.start()
       timeout = feed_timeout
-      while (joinThr.isAlive()):
+      while (joinThr.is_alive()):
         if (not equeue.empty()):
           e_str = equeue.get()
           raise Exception("Exception in worker:\n" + e_str)

tensorflowonspark/pipeline.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@


 logger = logging.getLogger(__name__)
-TF_VERSION = pkg_resources.get_distribution('tensorflow').version
+try:
+  TF_VERSION = pkg_resources.get_distribution('tensorflow').version
+except pkg_resources.DistributionNotFound:
+  TF_VERSION = pkg_resources.get_distribution('tensorflow-cpu').version


 # TensorFlowOnSpark Params

tensorflowonspark/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def single_node_env(num_gpus=1, worker_index=-1, nodes=[]):

   if gpu_info.is_gpu_available() and num_gpus > 0:
     # reserve GPU(s), if requested
-    if worker_index >= 0 and len(nodes) > 0:
+    if worker_index >= 0 and nodes and len(nodes) > 0:
       # compute my index relative to other nodes on the same host, if known
       my_addr = nodes[worker_index]
       my_host = my_addr.split(':')[0]

tests/test_TFSparkNode.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ def fn(args, ctx):
     map_fn = TFSparkNode.run(fn, tf_args, self.cluster_meta, self.tensorboard, self.log_dir, self.queues, self.background)
     map_fn([0])

-  def test_gpu_unavailable(self):
+  @patch('tensorflowonspark.gpu_info.is_gpu_available')
+  def test_gpu_unavailable(self, mock_available):
     """Request GPU with no GPUs available, expecting an exception"""
+    mock_available.return_value = False
     self.parser.add_argument("--num_gpus", help="number of gpus to use", type=int)
     tf_args = self.parser.parse_args(["--num_gpus", "1"])
0 commit comments

Comments
 (0)