Commit f685adb

Merge pull request #254 from yahoo/leewyang_dataset
Fix mnist dataset example for TFRecords
2 parents a9d661f + 7cd8bfe

6 files changed (+80, -67 lines)

examples/mnist/mnist_data_setup.py

Lines changed: 22 additions & 17 deletions

@@ -8,35 +8,39 @@
 
 import numpy
 import tensorflow as tf
-from array import array
 from tensorflow.contrib.learn.python.learn.datasets import mnist
 
+
 def toTFExample(image, label):
   """Serializes an image/label as a TFExample byte string"""
   example = tf.train.Example(
-    features = tf.train.Features(
-      feature = {
+    features=tf.train.Features(
+      feature={
         'label': tf.train.Feature(int64_list=tf.train.Int64List(value=label.astype("int64"))),
         'image': tf.train.Feature(int64_list=tf.train.Int64List(value=image.astype("int64")))
       }
     )
   )
   return example.SerializeToString()
 
+
 def fromTFExample(bytestr):
   """Deserializes a TFExample from a byte string"""
   example = tf.train.Example()
   example.ParseFromString(bytestr)
   return example
 
+
 def toCSV(vec):
   """Converts a vector/array into a CSV string"""
   return ','.join([str(i) for i in vec])
 
+
 def fromCSV(s):
   """Converts a CSV string to a vector/array"""
   return [float(x) for x in s.split(',') if len(s) > 0]
 
+
 def writeMNIST(sc, input_images, input_labels, output, format, num_partitions):
   """Writes MNIST image/label vectors into parallelized files on HDFS"""
   # load MNIST gzip into memory
@@ -69,12 +73,12 @@ def writeMNIST(sc, input_images, input_labels, output, format, num_partitions):
     labelRDD.map(toCSV).saveAsTextFile(output_labels)
   elif format == "csv2":
     imageRDD.map(toCSV).zip(labelRDD).map(lambda x: str(x[1]) + "|" + x[0]).saveAsTextFile(output)
-  else: # format == "tfr":
+  else:  # format == "tfr":
     tfRDD = imageRDD.zip(labelRDD).map(lambda x: (bytearray(toTFExample(x[0], x[1])), None))
     # requires: --jars tensorflow-hadoop-1.0-SNAPSHOT.jar
     tfRDD.saveAsNewAPIHadoopFile(output, "org.tensorflow.hadoop.io.TFRecordFileOutputFormat",
-            keyClass="org.apache.hadoop.io.BytesWritable",
-            valueClass="org.apache.hadoop.io.NullWritable")
+                                 keyClass="org.apache.hadoop.io.BytesWritable",
+                                 valueClass="org.apache.hadoop.io.NullWritable")
   # Note: this creates TFRecord files w/o requiring a custom Input/Output format
   # else: # format == "tfr":
   #   def writeTFRecords(index, iter):
@@ -86,6 +90,7 @@ def writeMNIST(sc, input_images, input_labels, output, format, num_partitions):
   #   tfRDD = imageRDD.zip(labelRDD).map(lambda x: toTFExample(x[0], x[1]))
   #   tfRDD.mapPartitionsWithIndex(writeTFRecords).collect()
 
+
 def readMNIST(sc, output, format):
   """Reads/verifies previously created output"""
 
@@ -100,12 +105,12 @@ def readMNIST(sc, output, format):
   elif format == "csv":
     imageRDD = sc.textFile(output_images).map(fromCSV)
     labelRDD = sc.textFile(output_labels).map(fromCSV)
-  else: # format.startswith("tf"):
+  else:  # format.startswith("tf"):
     # requires: --jars tensorflow-hadoop-1.0-SNAPSHOT.jar
     tfRDD = sc.newAPIHadoopFile(output, "org.tensorflow.hadoop.io.TFRecordFileInputFormat",
-            keyClass="org.apache.hadoop.io.BytesWritable",
-            valueClass="org.apache.hadoop.io.NullWritable")
-    imageRDD = tfRDD.map(lambda x: fromTFExample(str(x[0])))
+                                keyClass="org.apache.hadoop.io.BytesWritable",
+                                valueClass="org.apache.hadoop.io.NullWritable")
+    imageRDD = tfRDD.map(lambda x: fromTFExample(bytes(x[0])))
 
   num_images = imageRDD.count()
   num_labels = labelRDD.count() if labelRDD is not None else num_images
@@ -114,21 +119,22 @@ def readMNIST(sc, output, format):
   print("num_labels: ", num_labels)
   print("samples: ", samples)
 
+
 if __name__ == "__main__":
   import argparse
 
   from pyspark.context import SparkContext
   from pyspark.conf import SparkConf
 
   parser = argparse.ArgumentParser()
-  parser.add_argument("-f", "--format", help="output format", choices=["csv","csv2","pickle","tf","tfr"], default="csv")
-  parser.add_argument("-n", "--num-partitions", help="Number of output partitions", type=int, default=10)
-  parser.add_argument("-o", "--output", help="HDFS directory to save examples in parallelized format", default="mnist_data")
-  parser.add_argument("-r", "--read", help="read previously saved examples", action="store_true")
-  parser.add_argument("-v", "--verify", help="verify saved examples after writing", action="store_true")
+  parser.add_argument("--format", help="output format", choices=["csv", "csv2", "pickle", "tf", "tfr"], default="csv")
+  parser.add_argument("--num-partitions", help="Number of output partitions", type=int, default=10)
+  parser.add_argument("--output", help="HDFS directory to save examples in parallelized format", default="mnist_data")
+  parser.add_argument("--read", help="read previously saved examples", action="store_true")
+  parser.add_argument("--verify", help="verify saved examples after writing", action="store_true")
 
   args = parser.parse_args()
-  print("args:",args)
+  print("args:", args)
 
   sc = SparkContext(conf=SparkConf().setAppName("mnist_parallelize"))
 
@@ -139,4 +145,3 @@ def readMNIST(sc, output, format):
 
   if args.read or args.verify:
     readMNIST(sc, args.output + "/train", args.format)
-
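
Note on this file's changes: the unused "from array import array" import is dropped, PEP8 spacing is applied (two blank lines between top-level functions, no spaces around "=" in keyword arguments), the argparse options lose their single-letter aliases, and records read back through newAPIHadoopFile are now converted with bytes() instead of str() before protobuf parsing. A minimal round-trip sketch of the serialize/parse path (standalone and hypothetical: the random image and one-hot label are placeholder data, not part of the commit):

import numpy
import tensorflow as tf

# fake one MNIST-style record: 784 pixel values and a 10-way one-hot label
image = numpy.random.randint(0, 256, 784)
label = numpy.array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0])

# same structure that toTFExample() above builds
serialized = tf.train.Example(
    features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=label.astype("int64"))),
        'image': tf.train.Feature(int64_list=tf.train.Int64List(value=image.astype("int64")))
    })
).SerializeToString()

# Spark hands each TFRecord back as a bytearray in x[0]; ParseFromString()
# wants raw bytes, and bytes() yields them on both Python 2 and Python 3
example = tf.train.Example()
example.ParseFromString(bytes(bytearray(serialized)))
print(example.features.feature['label'].int64_list.value)  # [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]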

examples/mnist/spark/mnist_spark.py

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ def toNumpy(bytestr):
     label = numpy.array(features['label'].int64_list.value)
     return (image, label)
 
-  dataRDD = images.map(lambda x: toNumpy(str(x[0])))
+  dataRDD = images.map(lambda x: toNumpy(bytes(x[0])))
 else:
   if args.format == "csv":
     images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')])
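
The same one-character-class fix as above: under Python 3, str() of a bytearray returns its repr rather than the raw payload, so ParseFromString() inside toNumpy() would fail; bytes() behaves correctly under both interpreters. A two-line illustration (the value is arbitrary):

record = bytearray(b'\x08\x01')
str(record)    # Python 3: "bytearray(b'\\x08\\x01')" -- the repr, not the payload
bytes(record)  # b'\x08\x01' on Python 3; on Python 2, bytes is str, so this is the raw string as well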

examples/mnist/spark/mnist_spark_dataset.py

Lines changed: 14 additions & 14 deletions

@@ -23,19 +23,19 @@
 num_ps = 1
 
 parser = argparse.ArgumentParser()
-parser.add_argument("-b", "--batch_size", help="number of records per batch", type=int, default=100)
-parser.add_argument("-e", "--epochs", help="number of epochs", type=int, default=1)
-parser.add_argument("-f", "--format", help="example format: (csv|tfr)", choices=["csv", "tfr"], default="csv")
-parser.add_argument("-i", "--images", help="HDFS path to MNIST images in parallelized format")
-parser.add_argument("-l", "--labels", help="HDFS path to MNIST labels in parallelized format")
-parser.add_argument("-m", "--model", help="HDFS path to save/load model during train/inference", default="mnist_model")
-parser.add_argument("-n", "--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
-parser.add_argument("-o", "--output", help="HDFS path to save test/inference output", default="predictions")
-parser.add_argument("-r", "--readers", help="number of reader/enqueue threads", type=int, default=1)
-parser.add_argument("-s", "--steps", help="maximum number of steps", type=int, default=1000)
-parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
-parser.add_argument("-X", "--mode", help="train|inference", default="train")
-parser.add_argument("-c", "--rdma", help="use rdma connection", default=False)
+parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
+parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
+parser.add_argument("--format", help="example format: (csv|tfr)", choices=["csv", "tfr"], default="csv")
+parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
+parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format")
+parser.add_argument("--model", help="HDFS path to save/load model during train/inference", default="mnist_model")
+parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
+parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
+parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
+parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
+parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
+parser.add_argument("--mode", help="train|inference", default="train")
+parser.add_argument("--rdma", help="use rdma connection", default=False)
 args = parser.parse_args()
 print("args:", args)
 
@@ -54,7 +54,7 @@ def toNumpy(bytestr):
     label = numpy.array(features['label'].int64_list.value)
     return (image, label)
 
-  dataRDD = images.map(lambda x: toNumpy(str(x[0])))
+  dataRDD = images.map(lambda x: toNumpy(bytes(x[0])))
 else: # args.format == "csv":
   images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')])
   labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
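
Besides the bytes() fix, the parser here drops its single-letter aliases (-b, -e, -X, ...); the same cleanup recurs in the two tf/ drivers below. argparse still derives each attribute name from the long flag, so call sites like args.batch_size are untouched. A quick standalone check of the new behavior (the argument list passed to parse_args is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
parser.add_argument("--format", help="example format: (csv|tfr)", choices=["csv", "tfr"], default="csv")

args = parser.parse_args(["--batch_size", "200", "--format", "tfr"])
print(args.batch_size, args.format)  # 200 tfr
# parser.parse_args(["-b", "200"]) would now exit with an "unrecognized arguments" error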

examples/mnist/tf/mnist_dist_dataset.py

Lines changed: 13 additions & 7 deletions

@@ -22,6 +22,7 @@ def map_fun(args, ctx):
   import tensorflow as tf
   import time
 
+  num_workers = args.cluster_size if args.driver_ps_nodes else args.cluster_size - args.num_ps
   worker_num = ctx.worker_num
   job_name = ctx.job_name
   task_index = ctx.task_index
@@ -43,10 +44,9 @@ def _parse_csv(ln):
     normalized_image = tf.div(image, norm)
     label_value = tf.string_to_number(lbl, tf.int32)
     label = tf.one_hot(label_value, 10)
-    return (normalized_image, label, label_value)
+    return (normalized_image, label)
 
   def _parse_tfr(example_proto):
-    print("example_proto: {}".format(example_proto))
     feature_def = {"label": tf.FixedLenFeature(10, tf.int64),
                    "image": tf.FixedLenFeature(IMAGE_PIXELS * IMAGE_PIXELS, tf.int64)}
     features = tf.parse_single_example(example_proto, feature_def)
@@ -68,10 +68,17 @@ def _parse_tfr(example_proto):
   file_pattern = os.path.join(image_dir, 'part-*')
   files = tf.gfile.Glob(file_pattern)
 
-  parse_fn = _parse_tfr if args.format == 'tfr' else _parse_csv
-  ds = tf.data.TextLineDataset(files).map(parse_fn).batch(args.batch_size)
+  if args.format == 'csv2':
+    ds = tf.data.TextLineDataset(files)
+    parse_fn = _parse_csv
+  else:  # args.format == 'tfr'
+    ds = tf.data.TFRecordDataset(files)
+    parse_fn = _parse_tfr
+
+  ds = ds.shard(num_workers, task_index).repeat(args.epochs).shuffle(args.shuffle_size)
+  ds = ds.map(parse_fn).batch(args.batch_size)
   iterator = ds.make_initializable_iterator()
-  x, y_, y_val = iterator.get_next()
+  x, y_ = iterator.get_next()
 
   # Variables of the hidden layer
   hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
@@ -156,8 +163,7 @@ def _parse_tfr(example_proto):
       if args.mode == "train":
         if (step % 100 == 0):
           print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy)))
-        _, summary, step, yv = sess.run([train_op, summary_op, global_step, y_val])
-        # print("yval: {}".format(yv))
+        _, summary, step = sess.run([train_op, summary_op, global_step])
        if sv.is_chief:
          summary_writer.add_summary(summary, step)
      else: # args.mode == "inference"
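
This is the heart of the fix: the old pipeline forced every format through tf.data.TextLineDataset, which frames records by newline and therefore cannot read binary TFRecord files; the new code picks the dataset class per format and applies sharding, epochs, and shuffling before parsing. A condensed, standalone sketch of the resulting TFRecord path under TF 1.x (the glob path and the literal shard/epoch/shuffle/batch values are stand-ins for num_workers, task_index, and the args fields used above):

import tensorflow as tf

def _parse_tfr(example_proto):
    # shapes match mnist_data_setup.py: 10-way one-hot label, 28*28 pixels
    feature_def = {"label": tf.FixedLenFeature(10, tf.int64),
                   "image": tf.FixedLenFeature(784, tf.int64)}
    features = tf.parse_single_example(example_proto, feature_def)
    image = tf.to_float(features["image"]) / 255.0  # normalize pixels to [0, 1]
    label = tf.to_float(features["label"])
    return (image, label)

files = tf.gfile.Glob("mnist/tfr/train/part-*")  # hypothetical output of mnist_data_setup.py
ds = tf.data.TFRecordDataset(files)              # binary records, not text lines
ds = ds.shard(4, 0)                              # worker 0 of 4 reads only its slice
ds = ds.repeat(1).shuffle(1000)                  # one epoch, 1000-record shuffle buffer
ds = ds.map(_parse_tfr).batch(100)
iterator = ds.make_initializable_iterator()
x, y_ = iterator.get_next()                      # a 2-tuple, matching the slimmed-down parse fns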

examples/mnist/tf/mnist_spark.py

Lines changed: 13 additions & 13 deletions

@@ -21,19 +21,19 @@
 num_ps = 1
 
 parser = argparse.ArgumentParser()
-parser.add_argument("-e", "--epochs", help="number of epochs", type=int, default=0)
-parser.add_argument("-f", "--format", help="example format: (csv|pickle|tfr)", choices=["csv", "pickle", "tfr"], default="tfr")
-parser.add_argument("-i", "--images", help="HDFS path to MNIST images in parallelized format")
-parser.add_argument("-l", "--labels", help="HDFS path to MNIST labels in parallelized format")
-parser.add_argument("-m", "--model", help="HDFS path to save/load model during train/test", default="mnist_model")
-parser.add_argument("-n", "--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors)
-parser.add_argument("-o", "--output", help="HDFS path to save test/inference output", default="predictions")
-parser.add_argument("-r", "--readers", help="number of reader/enqueue threads", type=int, default=1)
-parser.add_argument("-s", "--steps", help="maximum number of steps", type=int, default=1000)
-parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
-parser.add_argument("-X", "--mode", help="train|inference", default="train")
-parser.add_argument("-c", "--rdma", help="use rdma connection", default=False)
-parser.add_argument("-p", "--driver_ps_nodes", help="run tensorflow PS node on driver locally", default=False)
+parser.add_argument("--epochs", help="number of epochs", type=int, default=0)
+parser.add_argument("--format", help="example format: (csv|pickle|tfr)", choices=["csv", "pickle", "tfr"], default="tfr")
+parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
+parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format")
+parser.add_argument("--model", help="HDFS path to save/load model during train/test", default="mnist_model")
+parser.add_argument("--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors)
+parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
+parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
+parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
+parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
+parser.add_argument("--mode", help="train|inference", default="train")
+parser.add_argument("--rdma", help="use rdma connection", default=False)
+parser.add_argument("--driver_ps_nodes", help="run tensorflow PS node on driver locally", default=False)
 args = parser.parse_args()
 print("args:", args)
 

examples/mnist/tf/mnist_spark_dataset.py

Lines changed: 17 additions & 15 deletions

@@ -21,27 +21,29 @@
 num_ps = 1
 
 parser = argparse.ArgumentParser()
-parser.add_argument("-b", "--batch_size", help="number of records per batch", type=int, default=100)
-parser.add_argument("-e", "--epochs", help="number of epochs", type=int, default=0)
-parser.add_argument("-f", "--format", help="example format: (csv2|tfr)", choices=["csv2", "tfr"], default="tfr")
-parser.add_argument("-i", "--images", help="HDFS path to MNIST images in parallelized format")
-parser.add_argument("-l", "--labels", help="HDFS path to MNIST labels in parallelized format")
-parser.add_argument("-m", "--model", help="HDFS path to save/load model during train/test", default="mnist_model")
-parser.add_argument("-n", "--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors)
-parser.add_argument("-o", "--output", help="HDFS path to save test/inference output", default="predictions")
-parser.add_argument("-r", "--readers", help="number of reader/enqueue threads", type=int, default=1)
-parser.add_argument("-s", "--steps", help="maximum number of steps", type=int, default=1000)
-parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
-parser.add_argument("-X", "--mode", help="train|inference", default="train")
-parser.add_argument("-c", "--rdma", help="use rdma connection", default=False)
-parser.add_argument("-p", "--driver_ps_nodes", help="""run tensorflow PS node on driver locally.
+parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
+parser.add_argument("--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors)
+parser.add_argument("--driver_ps_nodes", help="""run tensorflow PS node on driver locally.
     You will need to set cluster_size = num_executors + num_ps""", default=False)
+parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
+parser.add_argument("--format", help="example format: (csv2|tfr)", choices=["csv2", "tfr"], default="tfr")
+parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
+parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format")
+parser.add_argument("--mode", help="train|inference", default="train")
+parser.add_argument("--model", help="HDFS path to save/load model during train/test", default="mnist_model")
+parser.add_argument("--num_ps", help="number of ps nodes", default=1)
+parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
+parser.add_argument("--rdma", help="use rdma connection", default=False)
+parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
+parser.add_argument("--shuffle_size", help="size of shuffle buffer", type=int, default=1000)
+parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
+parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
 args = parser.parse_args()
 print("args:", args)
 
 
 print("{0} ===== Start".format(datetime.now().isoformat()))
-cluster = TFCluster.run(sc, mnist_dist_dataset.map_fun, args, args.cluster_size, num_ps, args.tensorboard,
+cluster = TFCluster.run(sc, mnist_dist_dataset.map_fun, args, args.cluster_size, args.num_ps, args.tensorboard,
                         TFCluster.InputMode.TENSORFLOW, driver_ps_nodes=args.driver_ps_nodes)
 cluster.shutdown()
 
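--num_ps and --shuffle_size become real flags here (and the argument list is alphabetized) because map_fun now consumes them: num_ps feeds the worker count behind ds.shard(...) in mnist_dist_dataset.py above, and shuffle_size sizes the shuffle buffer. The --epochs default also moves from 0 to 1, presumably because args.epochs is now passed straight to Dataset.repeat(), where a count of 0 would yield an empty dataset. The worker-count arithmetic from map_fun, spelled out with hypothetical values:

# PS tasks normally occupy slots within cluster_size; with --driver_ps_nodes
# they run on the driver instead, so every executor slot is a worker
cluster_size, num_ps, driver_ps_nodes = 5, 1, False
num_workers = cluster_size if driver_ps_nodes else cluster_size - num_ps
print(num_workers)  # 4 -> passed to ds.shard(num_workers, task_index)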
