Commit 10e2473

Merge pull request #488 from yahoo/leewyang_compat_gpu
TF2.1 release compatibility
2 parents: 26ff94d + 149a703

5 files changed: +15 −7 lines

tensorflowonspark/TFNode.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 
 from packaging import version
 from six.moves.queue import Empty
-from . import marker
+from . import compat, marker
 
 logger = logging.getLogger(__name__)
 
@@ -90,7 +90,7 @@ def start_cluster_server(ctx, num_gpus=1, rdma=False):
   cluster_spec = ctx.cluster_spec
   logging.info("{0}: Cluster spec: {1}".format(ctx.worker_num, cluster_spec))
 
-  if tf.test.is_built_with_cuda() and num_gpus > 0:
+  if compat.is_gpu_available() and num_gpus > 0:
     # compute my index relative to other nodes placed on the same host (for GPU allocation)
     my_addr = cluster_spec[ctx.job_name][ctx.task_index]
     my_host = my_addr.split(':')[0]
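
The second hunk cuts off just as the GPU-allocation logic begins: the task looks itself up in the cluster spec to work out which GPU it should claim relative to other tasks on the same host. A minimal sketch of that idea (the helper name host_relative_index is illustrative, not part of the library):

# Hypothetical sketch: rank this task among same-job tasks that share its host,
# given a cluster spec like {"worker": ["h1:2222", "h1:2223", "h2:2222"]}.
def host_relative_index(cluster_spec, job_name, task_index):
  my_addr = cluster_spec[job_name][task_index]
  my_host = my_addr.split(':')[0]
  # tasks on my host, in cluster-spec order; my position is my GPU offset
  local_peers = [a for a in cluster_spec[job_name] if a.split(':')[0] == my_host]
  return local_peers.index(my_addr)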

tensorflowonspark/TFSparkNode.py

Lines changed: 3 additions & 2 deletions
@@ -23,6 +23,7 @@
 
 from . import TFManager
 from . import TFNode
+from . import compat
 from . import gpu_info
 from . import marker
 from . import reservation
@@ -144,7 +145,7 @@ def _mapfn(iter):
     executor_id = i
 
     # check that there are enough available GPUs (if using tensorflow-gpu) before committing reservation on this node
-    if tf.test.is_built_with_cuda():
+    if compat.is_gpu_available():
       num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1
       gpus_to_use = gpu_info.get_gpus(num_gpus)
@@ -295,7 +296,7 @@ def _mapfn(iter):
     os.environ['TF_CONFIG'] = tf_config
 
     # reserve GPU(s) again, just before launching TF process (in case situation has changed)
-    if tf.test.is_built_with_cuda():
+    if compat.is_gpu_available():
      # compute my index relative to other nodes on the same host (for GPU allocation)
      my_addr = cluster_spec[job_name][task_index]
      my_host = my_addr.split(':')[0]
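
Both executor-side hunks keep the same structure: a cheap availability check gates the GPU reservation, which happens once before committing the executor's reservation and again immediately before launching the TF process, in case availability changed in between. A rough standalone sketch of that pattern, assuming (it is not shown here) that gpu_info.get_gpus(num_gpus) returns a comma-delimited string of free GPU ids suitable for CUDA_VISIBLE_DEVICES:

import os
from tensorflowonspark import compat, gpu_info

def reserve_gpus_if_available(num_gpus=1):
  # No-op on CPU-only installs or GPU-less hosts (runtime check on TF >= 2.1).
  if compat.is_gpu_available():
    gpus_to_use = gpu_info.get_gpus(num_gpus)         # assumed format: "0" or "0,1"
    os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use  # pin this process to the reserved GPU(s)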

tensorflowonspark/compat.py

Lines changed: 7 additions & 0 deletions
@@ -20,3 +20,10 @@ def disable_auto_shard(options):
     options.experimental_distribute.auto_shard = False
   else:
     options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
+
+
+def is_gpu_available():
+  if version.parse(tf.__version__) < version.parse('2.1.0'):
+    return tf.test.is_built_with_cuda()
+  else:
+    return len(tf.config.list_logical_devices('GPU')) > 0
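
This new shim is the heart of the change. Before TF 2.1, GPU support shipped in a separate tensorflow-gpu package, so the build-time check tf.test.is_built_with_cuda() was a usable proxy for "GPUs are in play". Starting with TF 2.1, the single tensorflow pip package includes GPU support by default, so that check returns True even on CPU-only machines; the shim therefore switches to a runtime query for logical GPU devices. A self-contained copy for experimentation (same logic, with the imports compat.py already has at module level):

from packaging import version
import tensorflow as tf

def is_gpu_available():
  # Build-time check before TF 2.1, runtime device check from 2.1 on.
  if version.parse(tf.__version__) < version.parse('2.1.0'):
    return tf.test.is_built_with_cuda()
  else:
    return len(tf.config.list_logical_devices('GPU')) > 0

print(is_gpu_available())  # False on a CPU-only machine running TF >= 2.1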

tensorflowonspark/util.py

Lines changed: 2 additions & 3 deletions
@@ -13,14 +13,13 @@
 import subprocess
 import errno
 from socket import error as socket_error
-from . import gpu_info
+from . import compat, gpu_info
 
 logger = logging.getLogger(__name__)
 
 
 def single_node_env(num_gpus=1, worker_index=-1, nodes=[]):
   """Setup environment variables for Hadoop compatibility and GPU allocation"""
-  import tensorflow as tf
   # ensure expanded CLASSPATH w/o glob characters (required for Spark 2.1 + JNI)
   if 'HADOOP_PREFIX' in os.environ and 'TFOS_CLASSPATH_UPDATED' not in os.environ:
     classpath = os.environ['CLASSPATH']
@@ -29,7 +28,7 @@ def single_node_env(num_gpus=1, worker_index=-1, nodes=[]):
     os.environ['CLASSPATH'] = classpath + os.pathsep + hadoop_classpath
     os.environ['TFOS_CLASSPATH_UPDATED'] = '1'
 
-  if tf.test.is_built_with_cuda() and num_gpus > 0:
+  if compat.is_gpu_available() and num_gpus > 0:
     # reserve GPU(s), if requested
     if worker_index >= 0 and len(nodes) > 0:
       # compute my index relative to other nodes on the same host, if known
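
Note the second deletion here: because the CUDA check moved into compat (which imports TensorFlow itself), single_node_env no longer needs its own local import tensorflow as tf. A hypothetical usage sketch of the function after this change:

from tensorflowonspark import util

# Set up env vars (Hadoop CLASSPATH, GPU reservation) for a single-node run;
# on a CPU-only TF >= 2.1 install, the GPU branch is now skipped correctly.
util.single_node_env(num_gpus=1)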

test/test_pipeline.py

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@ def _spark_train(args, ctx):
   import tensorflow as tf
   from tensorflowonspark import TFNode
 
+  tf.compat.v1.disable_eager_execution()
   tf.compat.v1.reset_default_graph()
   strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
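
The one-line test fix reflects a TF 2.x ordering rule: tf.compat.v1 graph-mode code only behaves like TF 1.x if eager execution is disabled first, and tf.compat.v1.disable_eager_execution() must be called before any graph, op, or session is created. A minimal sketch of the required ordering:

import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # must precede any graph/session construction
tf.compat.v1.reset_default_graph()
sess = tf.compat.v1.Session()           # behaves like a TF 1.x session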
