import numpy
import tensorflow as tf
-from array import array
from tensorflow.contrib.learn.python.learn.datasets import mnist

+

def toTFExample(image, label):
  """Serializes an image/label as a TFExample byte string"""
  example = tf.train.Example(
-    features = tf.train.Features(
-      feature = {
+    features=tf.train.Features(
+      feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=label.astype("int64"))),
        'image': tf.train.Feature(int64_list=tf.train.Int64List(value=image.astype("int64")))
      }
    )
  )
  return example.SerializeToString()

+
def fromTFExample(bytestr):
  """Deserializes a TFExample from a byte string"""
  example = tf.train.Example()
  example.ParseFromString(bytestr)
  return example
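A quick round-trip check of the two helpers above (an editor's sketch, not part of this commit; it assumes flattened numpy images and one-hot numpy labels, as the MNIST loader produces):

    import numpy
    image = numpy.zeros(784, dtype=numpy.float32)         # one flattened 28x28 image
    label = numpy.array([0, 0, 0, 0, 0, 0, 0, 1, 0, 0])  # one-hot label for digit 7
    record = toTFExample(image, label)                    # serialized byte string
    example = fromTFExample(record)
    assert list(example.features.feature['label'].int64_list.value) == [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]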
+
def toCSV(vec):
  """Converts a vector/array into a CSV string"""
  return ','.join([str(i) for i in vec])

+
def fromCSV(s):
  """Converts a CSV string to a vector/array"""
  return [float(x) for x in s.split(',') if len(s) > 0]
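And the CSV pair behaves like this (same caveat; values always come back as floats):

    s = toCSV([0.0, 1.0, 0.5])   # '0.0,1.0,0.5'
    v = fromCSV(s)               # [0.0, 1.0, 0.5]
    assert v == [0.0, 1.0, 0.5]
    assert fromCSV('') == []     # the len(s) > 0 guard makes empty strings safe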
+
def writeMNIST(sc, input_images, input_labels, output, format, num_partitions):
  """Writes MNIST image/label vectors into parallelized files on HDFS"""
  # load MNIST gzip into memory
@@ -69,12 +73,12 @@ def writeMNIST(sc, input_images, input_labels, output, format, num_partitions):
    labelRDD.map(toCSV).saveAsTextFile(output_labels)
  elif format == "csv2":
    imageRDD.map(toCSV).zip(labelRDD).map(lambda x: str(x[1]) + "|" + x[0]).saveAsTextFile(output)
-  else: # format == "tfr":
+  else:  # format == "tfr":
    tfRDD = imageRDD.zip(labelRDD).map(lambda x: (bytearray(toTFExample(x[0], x[1])), None))
    # requires: --jars tensorflow-hadoop-1.0-SNAPSHOT.jar
    tfRDD.saveAsNewAPIHadoopFile(output, "org.tensorflow.hadoop.io.TFRecordFileOutputFormat",
-                                 keyClass = "org.apache.hadoop.io.BytesWritable",
-                                 valueClass = "org.apache.hadoop.io.NullWritable")
+                                 keyClass="org.apache.hadoop.io.BytesWritable",
+                                 valueClass="org.apache.hadoop.io.NullWritable")
# Note: this creates TFRecord files w/o requiring a custom Input/Output format
# else: # format == "tfr":
#   def writeTFRecords(index, iter):
@@ -86,6 +90,7 @@ def writeMNIST(sc, input_images, input_labels, output, format, num_partitions):
#   tfRDD = imageRDD.zip(labelRDD).map(lambda x: toTFExample(x[0], x[1]))
#   tfRDD.mapPartitionsWithIndex(writeTFRecords).collect()
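For reference, the commented-out branch above writes TFRecords directly from each partition, with no Hadoop output format or extra jar. A minimal sketch of that idea (editor's illustration; the per-partition path is hypothetical, and it assumes TF 1.x's tf.python_io.TFRecordWriter can reach the output directory from each executor):

    def writeTFRecords(index, iter):
      path = "{0}/part-{1:05d}".format(output, index)  # hypothetical per-partition file
      writer = tf.python_io.TFRecordWriter(path)
      for record in iter:                              # serialized TFExample byte strings
        writer.write(record)
      writer.close()
      return [path]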
+
def readMNIST(sc, output, format):
  """Reads/verifies previously created output"""

@@ -100,12 +105,12 @@ def readMNIST(sc, output, format):
  elif format == "csv":
    imageRDD = sc.textFile(output_images).map(fromCSV)
    labelRDD = sc.textFile(output_labels).map(fromCSV)
-  else: # format.startswith("tf"):
+  else:  # format.startswith("tf"):
    # requires: --jars tensorflow-hadoop-1.0-SNAPSHOT.jar
    tfRDD = sc.newAPIHadoopFile(output, "org.tensorflow.hadoop.io.TFRecordFileInputFormat",
-                               keyClass = "org.apache.hadoop.io.BytesWritable",
-                               valueClass = "org.apache.hadoop.io.NullWritable")
-    imageRDD = tfRDD.map(lambda x: fromTFExample(str(x[0])))
+                               keyClass="org.apache.hadoop.io.BytesWritable",
+                               valueClass="org.apache.hadoop.io.NullWritable")
+    imageRDD = tfRDD.map(lambda x: fromTFExample(bytes(x[0])))

  num_images = imageRDD.count()
  num_labels = labelRDD.count() if labelRDD is not None else num_images
@@ -114,21 +119,22 @@ def readMNIST(sc, output, format):
  print("num_labels: ", num_labels)
  print("samples: ", samples)
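In the tf/tfr branch, imageRDD holds deserialized Examples rather than float lists; pulling the arrays back out looks like this (sketch, same hedges as above; the field names match the writer):

    example = imageRDD.first()
    image = list(example.features.feature['image'].int64_list.value)  # 784 ints
    label = list(example.features.feature['label'].int64_list.value)  # one-hot, 10 ints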
+
if __name__ == "__main__":
  import argparse

  from pyspark.context import SparkContext
  from pyspark.conf import SparkConf

  parser = argparse.ArgumentParser()
-  parser.add_argument("-f", "--format", help="output format", choices=["csv","csv2","pickle","tf","tfr"], default="csv")
-  parser.add_argument("-n", "--num-partitions", help="Number of output partitions", type=int, default=10)
-  parser.add_argument("-o", "--output", help="HDFS directory to save examples in parallelized format", default="mnist_data")
-  parser.add_argument("-r", "--read", help="read previously saved examples", action="store_true")
-  parser.add_argument("-v", "--verify", help="verify saved examples after writing", action="store_true")
+  parser.add_argument("--format", help="output format", choices=["csv", "csv2", "pickle", "tf", "tfr"], default="csv")
+  parser.add_argument("--num-partitions", help="Number of output partitions", type=int, default=10)
+  parser.add_argument("--output", help="HDFS directory to save examples in parallelized format", default="mnist_data")
+  parser.add_argument("--read", help="read previously saved examples", action="store_true")
+  parser.add_argument("--verify", help="verify saved examples after writing", action="store_true")

  args = parser.parse_args()
-  print("args:",args)
+  print("args:", args)

  sc = SparkContext(conf=SparkConf().setAppName("mnist_parallelize"))

@@ -139,4 +145,3 @@ def readMNIST(sc, output, format):

  if args.read or args.verify:
    readMNIST(sc, args.output + "/train", args.format)
-
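A typical invocation after this change (editor's example; the script name, master, and jar version are assumptions, not part of the commit):

    spark-submit \
      --master yarn \
      --jars tensorflow-hadoop-1.0-SNAPSHOT.jar \
      mnist_data_setup.py \
      --output mnist_data --format tfr --verify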