@@ -139,14 +139,18 @@ def run(fn, tf_args, cluster_meta, tensorboard, log_dir, queues, background):
139139 A nodeRDD.mapPartitions() function.
140140 """
141141 def _mapfn (iter ):
142+ import pyspark
143+
142144 # Note: consuming the input iterator helps PySpark re-use this worker.
143145 for i in iter :
144146 executor_id = i
145147
146148 # check that there are enough available GPUs (if using tensorflow-gpu) before committing reservation on this node
147- if gpu_info .is_gpu_available ():
148- num_gpus = tf_args .num_gpus if 'num_gpus' in tf_args else 1
149- gpus_to_use = gpu_info .get_gpus (num_gpus )
149+ # note: for Spark 3+ w/ GPU allocation, the required number of GPUs should be guaranteed by the resource manager
150+ if version .parse (pyspark .__version__ ).base_version < version .parse ('3.0.0' ).base_version :
151+ if gpu_info .is_gpu_available ():
152+ num_gpus = tf_args .num_gpus if 'num_gpus' in tf_args else 1
153+ gpus_to_use = gpu_info .get_gpus (num_gpus )
150154
151155 # assign TF job/task based on provided cluster_spec template (or use default/null values)
152156 job_name = 'default'
@@ -295,18 +299,34 @@ def _mapfn(iter):
295299 os .environ ['TF_CONFIG' ] = tf_config
296300
297301 # reserve GPU(s) again, just before launching TF process (in case situation has changed)
302+ # and set up CUDA_VISIBLE_DEVICES accordingly
298303 if gpu_info .is_gpu_available ():
299- # compute my index relative to other nodes on the same host (for GPU allocation)
300- my_addr = cluster_spec [job_name ][task_index ]
301- my_host = my_addr .split (':' )[0 ]
302- flattened = [v for sublist in cluster_spec .values () for v in sublist ]
303- local_peers = [p for p in flattened if p .startswith (my_host )]
304- my_index = local_peers .index (my_addr )
305-
306- num_gpus = tf_args .num_gpus if 'num_gpus' in tf_args else 1
307- gpus_to_use = gpu_info .get_gpus (num_gpus , my_index )
304+
305+ gpus_to_use = None
306+ # For Spark 3+, try to get GPU resources from TaskContext first
307+ if version .parse (pyspark .__version__ ).base_version >= version .parse ("3.0.0" ).base_version :
308+ from pyspark import TaskContext
309+ context = TaskContext ()
310+ if 'gpu' in context .resources ():
311+ # use ALL GPUs assigned by resource manager
312+ gpus = context .resources ()['gpu' ].addresses
313+ num_gpus = len (gpus )
314+ gpus_to_use = ',' .join (gpus )
315+
316+ if not gpus_to_use :
317+ # compute my index relative to other nodes on the same host (for GPU allocation)
318+ my_addr = cluster_spec [job_name ][task_index ]
319+ my_host = my_addr .split (':' )[0 ]
320+ flattened = [v for sublist in cluster_spec .values () for v in sublist ]
321+ local_peers = [p for p in flattened if p .startswith (my_host )]
322+ my_index = local_peers .index (my_addr )
323+
324+ # default to one GPU if not specified explicitly
325+ num_gpus = tf_args .num_gpus if 'num_gpus' in tf_args else 1
326+ gpus_to_use = gpu_info .get_gpus (num_gpus , my_index )
327+
308328 gpu_str = "GPUs" if num_gpus > 1 else "GPU"
309- logger .debug ("Requested {} {}, setting CUDA_VISIBLE_DEVICES={}" .format (num_gpus , gpu_str , gpus_to_use ))
329+ logger .info ("Requested {} {}, setting CUDA_VISIBLE_DEVICES={}" .format (num_gpus , gpu_str , gpus_to_use ))
310330 os .environ ['CUDA_VISIBLE_DEVICES' ] = gpus_to_use
311331
312332 # create a context object to hold metadata for TF
0 commit comments