
Commit df468ff

Merge pull request #473 from yahoo/leewyang_parallel
add TFParallel.run() API
2 parents: dca5538 + 531000d

9 files changed: +99 -25 lines changed
docs/source/tensorflowonspark.TFParallel.rst

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+tensorflowonspark\.TFParallel module
+====================================
+
+.. automodule:: tensorflowonspark.TFParallel
+    :members:
+    :undoc-members:
+    :show-inheritance:

docs/source/tensorflowonspark.rst

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ Submodules
    tensorflowonspark.TFCluster
    tensorflowonspark.TFManager
    tensorflowonspark.TFNode
+   tensorflowonspark.TFParallel
    tensorflowonspark.TFSparkNode
    tensorflowonspark.dfutil
    tensorflowonspark.gpu_info

examples/mnist/keras/mnist_inference.py

Lines changed: 6 additions & 16 deletions
@@ -21,16 +21,7 @@
 import tensorflow as tf


-def inference(it, num_workers, args):
-  from tensorflowonspark import util
-
-  # consume worker number from RDD partition iterator
-  for i in it:
-    worker_num = i
-    print("worker_num: {}".format(i))
-
-  # setup env for single-node TF
-  util.single_node_env()
+def inference(args, ctx):

   # load saved_model
   saved_model = tf.saved_model.load(args.export_dir, tags='serve')
@@ -48,14 +39,14 @@ def parse_tfr(example_proto):

   # define a new tf.data.Dataset (for inferencing)
   ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels))
-  ds = ds.shard(num_workers, worker_num)
+  ds = ds.shard(ctx.num_workers, ctx.worker_num)
   ds = ds.interleave(tf.data.TFRecordDataset)
   ds = ds.map(parse_tfr)
   ds = ds.batch(10)

   # create an output file per spark worker for the predictions
   tf.io.gfile.makedirs(args.output)
-  output_file = tf.io.gfile.GFile("{}/part-{:05d}".format(args.output, worker_num), mode='w')
+  output_file = tf.io.gfile.GFile("{}/part-{:05d}".format(args.output, ctx.worker_num), mode='w')

   for batch in ds:
     predictions = predict(conv2d_input=batch[0])
@@ -70,6 +61,7 @@ def parse_tfr(example_proto):
 if __name__ == '__main__':
   from pyspark.context import SparkContext
   from pyspark.conf import SparkConf
+  from tensorflowonspark import TFParallel

   sc = SparkContext(conf=SparkConf().setAppName("mnist_inference"))
   executors = sc._conf.get("spark.executor.instances")
@@ -83,7 +75,5 @@ def parse_tfr(example_proto):
   args, _ = parser.parse_known_args()
   print("args: {}".format(args))

-  # Not using TFCluster... just running single-node TF instances on each executor
-  nodes = list(range(args.cluster_size))
-  nodeRDD = sc.parallelize(list(range(args.cluster_size)), args.cluster_size)
-  nodeRDD.foreachPartition(lambda worker_num: inference(worker_num, args.cluster_size, args))
+  # Running single-node TF instances on each executor
+  TFParallel.run(sc, inference, args, args.cluster_size)
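
With this change, the example's per-executor function no longer consumes the worker number from an RDD partition iterator or calls `util.single_node_env()` itself; `TFParallel` handles both and passes a context object instead. A rough, illustrative sketch of the `(args, ctx)` contract that the refactored `inference()` relies on (the `map_fn` name below is a placeholder, not code from this commit):

```
# Illustrative sketch of the (args, ctx) contract used by TFParallel.run();
# `map_fn` is a placeholder name, not part of this commit.
import tensorflow as tf

def map_fn(args, ctx):
  # ctx is a TFNodeContext populated by TFParallel.run() with:
  #   ctx.worker_num / ctx.executor_id -- index of this executor (0..num_workers-1)
  #   ctx.num_workers                  -- total number of executors
  #   ctx.defaultFS                    -- default filesystem (e.g. "hdfs://..." or "file://")
  ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels))
  ds = ds.shard(ctx.num_workers, ctx.worker_num)  # each executor reads a disjoint shard
  ...
```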

examples/resnet/README.md

Lines changed: 4 additions & 2 deletions
@@ -4,14 +4,16 @@ Original Source: https://github.com/tensorflow/models/tree/master/official/visio

 This code is based on the Image Classification model from the official [TensorFlow Models](https://github.com/tensorflow/models) repository. This example already supports different forms of distribution via the `DistributionStrategy` API, so there isn't much additional work to convert it to TensorFlowOnSpark.

-Notes:
+Notes:
 - This example assumes that Spark, TensorFlow, and TensorFlowOnSpark are already installed.
 - For simplicity, this just uses a single-node Spark Standalone installation.

 #### Run the Single-Node Application

-First, make sure that you can run the example per the [original instructions](https://github.com/tensorflow/models/tree/68c3c65596b8fc624be15aef6eac3dc8952cbf23/official/vision/image_classification). For now, we'll just use the CIFAR-10 dataset. After cloning the [tensorflow/models](https://github.com/tensorflow/models) repository and downloading the dataset, you should be able to run the training as follows:
+First, make sure that you can run the example per the [original instructions](https://github.com/tensorflow/models/tree/68c3c65596b8fc624be15aef6eac3dc8952cbf23/official/vision/image_classification). For now, we'll just use the CIFAR-10 dataset. After cloning the [tensorflow/models](https://github.com/tensorflow/models) repository (checking out the `v2.0` tag with `git checkout v2.0`), and downloading the dataset, you should be able to run the training as follows:
 ```
+# Note: these instructions have been tested with the `v2.0` tag of tensorflow/models.
+
 export TENSORFLOW_MODELS=/path/to/tensorflow/models
 export CIFAR_DATA=/path/to/cifar
 export PYTHONPATH=${PYTHONPATH}:${TENSORFLOW_MODELS}

tensorflowonspark/TFNode.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@

 logger = logging.getLogger(__name__)

+
 def hdfs_path(ctx, path):
   """Convenience function to create a Tensorflow-compatible absolute HDFS path from relative paths

tensorflowonspark/TFParallel.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+# Copyright 2019 Yahoo Inc / Verizon Media
+# Licensed under the terms of the Apache 2.0 license.
+# Please see LICENSE file in the project root for terms.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import nested_scopes
+from __future__ import print_function
+
+import logging
+from . import TFSparkNode
+from . import gpu_info, util
+
+logger = logging.getLogger(__name__)
+
+
+def run(sc, map_fn, tf_args, num_executors):
+  """Runs the user map_fn as parallel, independent instances of TF on the Spark executors.
+
+  Args:
+    :sc: SparkContext
+    :map_fn: user-supplied TensorFlow "main" function
+    :tf_args: ``argparse`` args, or command-line ``ARGV``.  These will be passed to the ``map_fn``.
+    :num_executors: number of Spark executors.  This should match your Spark job's ``--num_executors``.
+
+  Returns:
+    None
+  """
+
+  # get default filesystem from spark
+  defaultFS = sc._jsc.hadoopConfiguration().get("fs.defaultFS")
+  # strip trailing "root" slash from "file:///" to be consistent w/ "hdfs://..."
+  if defaultFS.startswith("file://") and len(defaultFS) > 7 and defaultFS.endswith("/"):
+    defaultFS = defaultFS[:-1]
+
+  def _run(it):
+    from pyspark import BarrierTaskContext
+
+    for i in it:
+      worker_num = i
+
+    # use BarrierTaskContext to get placement of all nodes
+    ctx = BarrierTaskContext.get()
+    tasks = ctx.getTaskInfos()
+    nodes = [t.address for t in tasks]
+
+    # use the placement info to help allocate GPUs
+    num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1
+    util.single_node_env(num_gpus=num_gpus, worker_index=worker_num, nodes=nodes)
+
+    # run the user map_fn
+    ctx = TFSparkNode.TFNodeContext()
+    ctx.defaultFS = defaultFS
+    ctx.worker_num = worker_num
+    ctx.executor_id = worker_num
+    ctx.num_workers = len(nodes)
+
+    map_fn(tf_args, ctx)
+
+    # return a dummy iterator (since we have to use mapPartitions)
+    return [0]
+
+  nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)
+  nodeRDD.barrier().mapPartitions(_run).collect()
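
For reference, a minimal driver-side usage sketch of the new API (the app name, `main_fun`, and the argument handling below are illustrative, mirroring the mnist example above; they are not part of this commit):

```
# Hypothetical usage of TFParallel.run(); names and argument parsing are illustrative.
import argparse
from pyspark.context import SparkContext
from pyspark.conf import SparkConf
from tensorflowonspark import TFParallel

def main_fun(args, ctx):
  # runs as an independent, single-node TF instance on each executor
  print("worker {} of {}".format(ctx.worker_num, ctx.num_workers))

if __name__ == '__main__':
  sc = SparkContext(conf=SparkConf().setAppName("tfparallel_example"))
  executors = sc._conf.get("spark.executor.instances")
  num_executors = int(executors) if executors is not None else 1

  parser = argparse.ArgumentParser()
  parser.add_argument("--cluster_size", type=int, default=num_executors)
  args, _ = parser.parse_known_args()

  # one barrier task per executor; the cluster must be able to schedule all of
  # them simultaneously (Spark barrier execution mode)
  TFParallel.run(sc, main_fun, args, args.cluster_size)
```

Unlike `TFCluster.run()`, no TensorFlow cluster spec is built here; the barrier stage is only used to discover node placement for GPU allocation.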

tensorflowonspark/TFSparkNode.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ class TFNodeContext:
     :working_dir: the current working directory for local filesystems, or YARN containers.
     :mgr: TFManager instance for this Python worker.
   """
-  def __init__(self, executor_id, job_name, task_index, cluster_spec, defaultFS, working_dir, mgr):
+  def __init__(self, executor_id=0, job_name='', task_index=0, cluster_spec={}, defaultFS='file://', working_dir='.', mgr=None):
     self.worker_num = executor_id  # for backwards-compatibility
     self.executor_id = executor_id
     self.job_name = job_name

tensorflowonspark/reservation.py

Lines changed: 1 addition & 2 deletions
@@ -190,9 +190,8 @@ def _listen(self, sock):
   def get_server_ip(self):
     return os.getenv(TFOS_SERVER_HOST) if os.getenv(TFOS_SERVER_HOST) else util.get_ip_address()

-
   def start_listening_socket(self):
-    port_number = int(os.getenv(TFOS_SERVER_PORT)) if os.getenv(TFOS_SERVER_PORT) else 0
+    port_number = int(os.getenv(TFOS_SERVER_PORT)) if os.getenv(TFOS_SERVER_PORT) else 0
     server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
     server_sock.bind(('', port_number))

tensorflowonspark/util.py

Lines changed: 14 additions & 4 deletions
@@ -18,7 +18,7 @@
 logger = logging.getLogger(__name__)


-def single_node_env(num_gpus=1):
+def single_node_env(num_gpus=1, worker_index=-1, nodes=[]):
   """Setup environment variables for Hadoop compatibility and GPU allocation"""
   import tensorflow as tf
   # ensure expanded CLASSPATH w/o glob characters (required for Spark 2.1 + JNI)
@@ -29,9 +29,19 @@ def single_node_env(num_gpus=1):
     os.environ['CLASSPATH'] = classpath + os.pathsep + hadoop_classpath
     os.environ['TFOS_CLASSPATH_UPDATED'] = '1'

-  # reserve GPU, if requested
-  if tf.test.is_built_with_cuda():
-    gpus_to_use = gpu_info.get_gpus(num_gpus)
+  if tf.test.is_built_with_cuda() and num_gpus > 0:
+    # reserve GPU(s), if requested
+    if worker_index >= 0 and len(nodes) > 0:
+      # compute my index relative to other nodes on the same host, if known
+      my_addr = nodes[worker_index]
+      my_host = my_addr.split(':')[0]
+      local_peers = [n for n in nodes if n.startswith(my_host)]
+      my_index = local_peers.index(my_addr)
+    else:
+      # otherwise, just use global worker index
+      my_index = worker_index
+
+    gpus_to_use = gpu_info.get_gpus(num_gpus, my_index)
     logger.info("Using gpu(s): {0}".format(gpus_to_use))
     os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use
   else:
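
To illustrate the new allocation logic: when placement info is available, the index passed to `gpu_info.get_gpus()` is computed relative to peers on the same host, so multiple executors on one machine request distinct GPU slots. A small, self-contained sketch of that computation (the host names and ports below are made up):

```
# Standalone illustration of the local-peer index computed in single_node_env();
# the node addresses are hypothetical.
nodes = ["host1:35001", "host2:35002", "host1:35003"]

def local_gpu_index(worker_index, nodes):
  my_addr = nodes[worker_index]
  my_host = my_addr.split(':')[0]
  local_peers = [n for n in nodes if n.startswith(my_host)]
  return local_peers.index(my_addr)

print(local_gpu_index(0, nodes))  # 0 -> first GPU slot on host1
print(local_gpu_index(1, nodes))  # 0 -> first GPU slot on host2
print(local_gpu_index(2, nodes))  # 1 -> second GPU slot on host1
```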
