from pyspark.streaming import StreamingContext

import argparse
import numpy
from datetime import datetime

from tensorflowonspark import TFCluster
import mnist_dist

# NOTE(review): SparkContext and SparkConf are used below but not imported in
# this chunk (which starts at file line 11) — presumably imported earlier in
# the file; verify against the full source.
# Create the SparkContext and a StreamingContext with a 60-second batch interval.
sc = SparkContext(conf=SparkConf().setAppName("mnist_streaming"))
ssc = StreamingContext(sc, 60)

# Size the TF cluster from the Spark executor count; default to 1 when the
# property is unset (e.g. local mode). One node is reserved as parameter server.
executors = sc._conf.get("spark.executor.instances")
num_executors = int(executors) if executors is not None else 1
num_ps = 1

# Command-line arguments for the streaming MNIST train/inference job.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
# NOTE(review): the default "stream" is not among the choices; argparse does not
# validate defaults, so this only errors if --format is passed explicitly.
parser.add_argument("--format", help="example format: (csv|csv2|pickle|tfr)", choices=["csv", "csv2", "pickle", "tfr"], default="stream")
parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
parser.add_argument("--model", help="HDFS path to save/load model during train/inference", default="mnist_model")
parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
parser.add_argument("--mode", help="train|inference", default="train")
parser.add_argument("--rdma", help="use rdma connection", default=False)
args = parser.parse_args()
print("args:", args)

print("{0} ===== Start".format(datetime.now().isoformat()))


def parse(ln):
    """Parse one 'label|pixels' text line into an (image, one-hot label) pair.

    Args:
      ln: string of the form "<digit>|<comma-separated pixel ints>".

    Returns:
      Tuple of (list of int pixel values, numpy one-hot label vector of length 10).
    """
    lbl, img = ln.split('|')
    image = [int(x) for x in img.split(',')]
    label = numpy.zeros(10)
    label[int(lbl)] = 1.0
    return (image, label)


# Tail new text files under args.images and parse each line into (image, label).
stream = ssc.textFileStream(args.images)
imageRDD = stream.map(lambda ln: parse(ln))
# NOTE(review): a diff hunk marker appeared here — the intervening lines
# (TFCluster.run / train-vs-inference dispatch / ssc.start) are not shown
# in this chunk; restore them from the full source.
# Shut down the TF cluster once the streaming context terminates.
# NOTE(review): `cluster` is created in lines elided from this chunk
# (presumably TFCluster.run(...)); confirm against the full source.
cluster.shutdown(ssc)

print("{0} ===== Stop".format(datetime.now().isoformat()))