Skip to content

Commit edacbb4

Browse files
committed
more pep8
1 parent abfd430 commit edacbb4

File tree

3 files changed: 47 additions (+) and 42 deletions (−)

examples/mnist/mnist_data_setup.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,35 +8,39 @@
88

99
import numpy
1010
import tensorflow as tf
11-
from array import array
1211
from tensorflow.contrib.learn.python.learn.datasets import mnist
1312

13+
1414
def toTFExample(image, label):
  """Serialize one image/label pair as a TFExample byte string.

  Args:
    image: numpy array of pixel values (integer-convertible).
    label: numpy array holding the label (integer-convertible).

  Returns:
    The serialized tf.train.Example protobuf as a byte string.
  """
  # Build the feature map first, then wrap it in an Example proto.
  feature_map = {
    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=label.astype("int64"))),
    'image': tf.train.Feature(int64_list=tf.train.Int64List(value=image.astype("int64")))
  }
  example = tf.train.Example(features=tf.train.Features(feature=feature_map))
  return example.SerializeToString()
2525

26+
2627
def fromTFExample(bytestr):
  """Deserialize a TFExample protobuf from a byte string.

  Args:
    bytestr: serialized tf.train.Example bytes.

  Returns:
    The parsed tf.train.Example message.
  """
  parsed = tf.train.Example()
  parsed.ParseFromString(bytestr)
  return parsed
3132

33+
3234
def toCSV(vec):
  """Convert a vector/array into a comma-separated string.

  Args:
    vec: any iterable of values.

  Returns:
    A string with each element's str() form joined by commas.
  """
  return ','.join(map(str, vec))
3537

38+
3639
def fromCSV(s):
  """Convert a CSV string into a list of floats.

  Args:
    s: comma-separated string of numbers; an empty string yields [].

  Returns:
    List of floats parsed from ``s``.
  """
  # Hoist the empty-string guard out of the comprehension: the original
  # re-evaluated the loop-invariant `len(s) > 0` for every element.
  if not s:
    return []
  return [float(x) for x in s.split(',')]
3942

43+
4044
def writeMNIST(sc, input_images, input_labels, output, format, num_partitions):
4145
"""Writes MNIST image/label vectors into parallelized files on HDFS"""
4246
# load MNIST gzip into memory
@@ -69,12 +73,12 @@ def writeMNIST(sc, input_images, input_labels, output, format, num_partitions):
6973
labelRDD.map(toCSV).saveAsTextFile(output_labels)
7074
elif format == "csv2":
7175
imageRDD.map(toCSV).zip(labelRDD).map(lambda x: str(x[1]) + "|" + x[0]).saveAsTextFile(output)
72-
else: # format == "tfr":
76+
else: # format == "tfr":
7377
tfRDD = imageRDD.zip(labelRDD).map(lambda x: (bytearray(toTFExample(x[0], x[1])), None))
7478
# requires: --jars tensorflow-hadoop-1.0-SNAPSHOT.jar
7579
tfRDD.saveAsNewAPIHadoopFile(output, "org.tensorflow.hadoop.io.TFRecordFileOutputFormat",
76-
keyClass="org.apache.hadoop.io.BytesWritable",
77-
valueClass="org.apache.hadoop.io.NullWritable")
80+
keyClass="org.apache.hadoop.io.BytesWritable",
81+
valueClass="org.apache.hadoop.io.NullWritable")
7882
# Note: this creates TFRecord files w/o requiring a custom Input/Output format
7983
# else: # format == "tfr":
8084
# def writeTFRecords(index, iter):
@@ -86,6 +90,7 @@ def writeMNIST(sc, input_images, input_labels, output, format, num_partitions):
8690
# tfRDD = imageRDD.zip(labelRDD).map(lambda x: toTFExample(x[0], x[1]))
8791
# tfRDD.mapPartitionsWithIndex(writeTFRecords).collect()
8892

93+
8994
def readMNIST(sc, output, format):
9095
"""Reads/verifies previously created output"""
9196

@@ -100,11 +105,11 @@ def readMNIST(sc, output, format):
100105
elif format == "csv":
101106
imageRDD = sc.textFile(output_images).map(fromCSV)
102107
labelRDD = sc.textFile(output_labels).map(fromCSV)
103-
else: # format.startswith("tf"):
108+
else: # format.startswith("tf"):
104109
# requires: --jars tensorflow-hadoop-1.0-SNAPSHOT.jar
105110
tfRDD = sc.newAPIHadoopFile(output, "org.tensorflow.hadoop.io.TFRecordFileInputFormat",
106-
keyClass="org.apache.hadoop.io.BytesWritable",
107-
valueClass="org.apache.hadoop.io.NullWritable")
111+
keyClass="org.apache.hadoop.io.BytesWritable",
112+
valueClass="org.apache.hadoop.io.NullWritable")
108113
imageRDD = tfRDD.map(lambda x: fromTFExample(str(x[0])))
109114

110115
num_images = imageRDD.count()
@@ -114,21 +119,22 @@ def readMNIST(sc, output, format):
114119
print("num_labels: ", num_labels)
115120
print("samples: ", samples)
116121

122+
117123
if __name__ == "__main__":
118124
import argparse
119125

120126
from pyspark.context import SparkContext
121127
from pyspark.conf import SparkConf
122128

123129
parser = argparse.ArgumentParser()
124-
parser.add_argument("-f", "--format", help="output format", choices=["csv","csv2","pickle","tf","tfr"], default="csv")
125-
parser.add_argument("-n", "--num-partitions", help="Number of output partitions", type=int, default=10)
126-
parser.add_argument("-o", "--output", help="HDFS directory to save examples in parallelized format", default="mnist_data")
127-
parser.add_argument("-r", "--read", help="read previously saved examples", action="store_true")
128-
parser.add_argument("-v", "--verify", help="verify saved examples after writing", action="store_true")
130+
parser.add_argument("--format", help="output format", choices=["csv", "csv2", "pickle", "tf", "tfr"], default="csv")
131+
parser.add_argument("--num-partitions", help="Number of output partitions", type=int, default=10)
132+
parser.add_argument("--output", help="HDFS directory to save examples in parallelized format", default="mnist_data")
133+
parser.add_argument("--read", help="read previously saved examples", action="store_true")
134+
parser.add_argument("--verify", help="verify saved examples after writing", action="store_true")
129135

130136
args = parser.parse_args()
131-
print("args:",args)
137+
print("args:", args)
132138

133139
sc = SparkContext(conf=SparkConf().setAppName("mnist_parallelize"))
134140

@@ -139,4 +145,3 @@ def readMNIST(sc, output, format):
139145

140146
if args.read or args.verify:
141147
readMNIST(sc, args.output + "/train", args.format)
142-

examples/mnist/spark/mnist_spark_dataset.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,19 @@
2323
num_ps = 1
2424

2525
parser = argparse.ArgumentParser()
26-
parser.add_argument("-b", "--batch_size", help="number of records per batch", type=int, default=100)
27-
parser.add_argument("-e", "--epochs", help="number of epochs", type=int, default=1)
28-
parser.add_argument("-f", "--format", help="example format: (csv|tfr)", choices=["csv", "tfr"], default="csv")
29-
parser.add_argument("-i", "--images", help="HDFS path to MNIST images in parallelized format")
30-
parser.add_argument("-l", "--labels", help="HDFS path to MNIST labels in parallelized format")
31-
parser.add_argument("-m", "--model", help="HDFS path to save/load model during train/inference", default="mnist_model")
32-
parser.add_argument("-n", "--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
33-
parser.add_argument("-o", "--output", help="HDFS path to save test/inference output", default="predictions")
34-
parser.add_argument("-r", "--readers", help="number of reader/enqueue threads", type=int, default=1)
35-
parser.add_argument("-s", "--steps", help="maximum number of steps", type=int, default=1000)
36-
parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
37-
parser.add_argument("-X", "--mode", help="train|inference", default="train")
38-
parser.add_argument("-c", "--rdma", help="use rdma connection", default=False)
26+
parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
27+
parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
28+
parser.add_argument("--format", help="example format: (csv|tfr)", choices=["csv", "tfr"], default="csv")
29+
parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
30+
parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format")
31+
parser.add_argument("--model", help="HDFS path to save/load model during train/inference", default="mnist_model")
32+
parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
33+
parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
34+
parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
35+
parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
36+
parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
37+
parser.add_argument("--mode", help="train|inference", default="train")
38+
parser.add_argument("--rdma", help="use rdma connection", default=False)
3939
args = parser.parse_args()
4040
print("args:", args)
4141

examples/mnist/tf/mnist_spark.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,19 @@
2121
num_ps = 1
2222

2323
parser = argparse.ArgumentParser()
24-
parser.add_argument("-e", "--epochs", help="number of epochs", type=int, default=0)
25-
parser.add_argument("-f", "--format", help="example format: (csv|pickle|tfr)", choices=["csv", "pickle", "tfr"], default="tfr")
26-
parser.add_argument("-i", "--images", help="HDFS path to MNIST images in parallelized format")
27-
parser.add_argument("-l", "--labels", help="HDFS path to MNIST labels in parallelized format")
28-
parser.add_argument("-m", "--model", help="HDFS path to save/load model during train/test", default="mnist_model")
29-
parser.add_argument("-n", "--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors)
30-
parser.add_argument("-o", "--output", help="HDFS path to save test/inference output", default="predictions")
31-
parser.add_argument("-r", "--readers", help="number of reader/enqueue threads", type=int, default=1)
32-
parser.add_argument("-s", "--steps", help="maximum number of steps", type=int, default=1000)
33-
parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
34-
parser.add_argument("-X", "--mode", help="train|inference", default="train")
35-
parser.add_argument("-c", "--rdma", help="use rdma connection", default=False)
36-
parser.add_argument("-p", "--driver_ps_nodes", help="run tensorflow PS node on driver locally", default=False)
24+
parser.add_argument("--epochs", help="number of epochs", type=int, default=0)
25+
parser.add_argument("--format", help="example format: (csv|pickle|tfr)", choices=["csv", "pickle", "tfr"], default="tfr")
26+
parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
27+
parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format")
28+
parser.add_argument("--model", help="HDFS path to save/load model during train/test", default="mnist_model")
29+
parser.add_argument("--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors)
30+
parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
31+
parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
32+
parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
33+
parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
34+
parser.add_argument("--mode", help="train|inference", default="train")
35+
parser.add_argument("--rdma", help="use rdma connection", default=False)
36+
parser.add_argument("--driver_ps_nodes", help="run tensorflow PS node on driver locally", default=False)
3737
args = parser.parse_args()
3838
print("args:", args)
3939

0 commit comments

Comments
 (0)