
Commit 92a7877

Merge pull request #493 from yahoo/leewyang_spark3_gpu
Spark3 GPU allocation
2 parents: 64d6f55 + 4a36455

2 files changed: 35 additions, 13 deletions

tensorflowonspark/TFNode.py
2 additions, 0 deletions

@@ -137,6 +137,8 @@ def start_cluster_server(ctx, num_gpus=1, rdma=False):
       raise Exception("Failed to allocate GPU")
   else:
     # CPU
+    import tensorflow as tf
+
     os.environ['CUDA_VISIBLE_DEVICES'] = ''
     logging.info("{0}: Using CPU".format(ctx.worker_num))

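The CPU branch relies on CUDA device masking: an empty CUDA_VISIBLE_DEVICES, exported before TensorFlow creates its GPU context, hides every GPU from the process. A minimal standalone sketch of that behavior (assuming TensorFlow 2.x; not part of this commit):

import os

# An empty CUDA_VISIBLE_DEVICES makes the CUDA runtime enumerate no devices,
# so TensorFlow falls back to CPU-only execution.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import tensorflow as tf

# No GPUs are visible, regardless of the hardware on the host.
print(tf.config.list_physical_devices('GPU'))  # expected: []
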
tensorflowonspark/TFSparkNode.py
33 additions, 13 deletions (file mode changed: 100755 -> 100644)

@@ -139,14 +139,18 @@ def run(fn, tf_args, cluster_meta, tensorboard, log_dir, queues, background):
     A nodeRDD.mapPartitions() function.
   """
   def _mapfn(iter):
+    import pyspark
+
     # Note: consuming the input iterator helps Pyspark re-use this worker,
     for i in iter:
       executor_id = i
 
     # check that there are enough available GPUs (if using tensorflow-gpu) before committing reservation on this node
-    if gpu_info.is_gpu_available():
-      num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1
-      gpus_to_use = gpu_info.get_gpus(num_gpus)
+    # note: for Spark 3+ w/ GPU allocation, the required number of GPUs should be guaranteed by the resource manager
+    if version.parse(pyspark.__version__).base_version < version.parse('3.0.0').base_version:
+      if gpu_info.is_gpu_available():
+        num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1
+        gpus_to_use = gpu_info.get_gpus(num_gpus)
 
     # assign TF job/task based on provided cluster_spec template (or use default/null values)
     job_name = 'default'
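
The new guard above runs TensorFlowOnSpark's own GPU reservation only on Spark versions below 3.0.0, since Spark 3's resource scheduling already guarantees the requested GPU count. A minimal sketch of an equivalent check (not part of the commit; note that base_version is a string, so re-parsing it yields a proper Version comparison rather than the lexicographic one in the diff):

from packaging import version
import pyspark

def is_spark3_or_later(ver_str=None):
  """True if the given (or running) PySpark is 3.0.0+; pre-releases of 3.0.0 count."""
  ver_str = ver_str or pyspark.__version__
  base = version.parse(ver_str).base_version  # e.g. '3.0.0rc1' -> '3.0.0'
  return version.parse(base) >= version.parse('3.0.0')

assert is_spark3_or_later('3.0.0rc1')
assert not is_spark3_or_later('2.4.5')
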
@@ -295,18 +299,34 @@ def _mapfn(iter):
       os.environ['TF_CONFIG'] = tf_config
 
     # reserve GPU(s) again, just before launching TF process (in case situation has changed)
+    # and setup CUDA_VISIBLE_DEVICES accordingly
     if gpu_info.is_gpu_available():
-      # compute my index relative to other nodes on the same host (for GPU allocation)
-      my_addr = cluster_spec[job_name][task_index]
-      my_host = my_addr.split(':')[0]
-      flattened = [v for sublist in cluster_spec.values() for v in sublist]
-      local_peers = [p for p in flattened if p.startswith(my_host)]
-      my_index = local_peers.index(my_addr)
-
-      num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1
-      gpus_to_use = gpu_info.get_gpus(num_gpus, my_index)
+
+      gpus_to_use = None
+      # For Spark 3+, try to get GPU resources from TaskContext first
+      if version.parse(pyspark.__version__).base_version >= version.parse("3.0.0").base_version:
+        from pyspark import TaskContext
+        context = TaskContext()
+        if 'gpu' in context.resources():
+          # use ALL GPUs assigned by resource manager
+          gpus = context.resources()['gpu'].addresses
+          num_gpus = len(gpus)
+          gpus_to_use = ','.join(gpus)
+
+      if not gpus_to_use:
+        # compute my index relative to other nodes on the same host (for GPU allocation)
+        my_addr = cluster_spec[job_name][task_index]
+        my_host = my_addr.split(':')[0]
+        flattened = [v for sublist in cluster_spec.values() for v in sublist]
+        local_peers = [p for p in flattened if p.startswith(my_host)]
+        my_index = local_peers.index(my_addr)
+
+        # default to one GPU if not specified explicitly
+        num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1
+        gpus_to_use = gpu_info.get_gpus(num_gpus, my_index)
+
       gpu_str = "GPUs" if num_gpus > 1 else "GPU"
-      logger.debug("Requested {} {}, setting CUDA_VISIBLE_DEVICES={}".format(num_gpus, gpu_str, gpus_to_use))
+      logger.info("Requested {} {}, setting CUDA_VISIBLE_DEVICES={}".format(num_gpus, gpu_str, gpus_to_use))
       os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use
 
     # create a context object to hold metadata for TF

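The Spark 3 path above asks the TaskContext for the GPU addresses that the resource manager already assigned to the task, instead of probing for free GPUs itself. A runnable sketch of that API on a GPU-scheduled Spark 3 cluster (hypothetical job, not from the commit; assumes e.g. spark.task.resource.gpu.amount has been configured):

from pyspark import TaskContext
from pyspark.sql import SparkSession

def gpus_for_partition(_):
  ctx = TaskContext.get()  # context of the currently running task
  resources = ctx.resources()  # dict of resource name -> ResourceInformation
  if 'gpu' in resources:
    # addresses are strings like ['0', '1'], ready for CUDA_VISIBLE_DEVICES
    yield ','.join(resources['gpu'].addresses)
  else:
    yield 'no gpus assigned'

spark = SparkSession.builder.getOrCreate()
rdd = spark.sparkContext.parallelize(range(2), 2)
print(rdd.mapPartitions(gpus_for_partition).collect())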