Skip to content

Commit bc8bddd

Browse files
authored
Merge pull request #236 from yahoo/leewyang_stream_interval
increase spark streaming interval for examples
2 parents 2289020 + 26115b9 commit bc8bddd

File tree

1 file changed

+16
-20
lines changed

1 file changed

+16
-20
lines changed

examples/mnist/streaming/mnist_spark.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,46 +11,43 @@
1111
from pyspark.streaming import StreamingContext
1212

1313
import argparse
14-
import os
1514
import numpy
16-
import sys
17-
import tensorflow as tf
18-
import threading
19-
import time
2015
from datetime import datetime
2116

2217
from tensorflowonspark import TFCluster
2318
import mnist_dist
2419

2520
sc = SparkContext(conf=SparkConf().setAppName("mnist_streaming"))
26-
ssc = StreamingContext(sc, 10)
21+
ssc = StreamingContext(sc, 60)
2722
executors = sc._conf.get("spark.executor.instances")
2823
num_executors = int(executors) if executors is not None else 1
2924
num_ps = 1
3025

3126
parser = argparse.ArgumentParser()
32-
parser.add_argument("-b", "--batch_size", help="number of records per batch", type=int, default=100)
33-
parser.add_argument("-e", "--epochs", help="number of epochs", type=int, default=1)
34-
parser.add_argument("-f", "--format", help="example format: (csv|csv2|pickle|tfr)", choices=["csv","csv2","pickle","tfr"], default="stream")
35-
parser.add_argument("-i", "--images", help="HDFS path to MNIST images in parallelized format")
36-
parser.add_argument("-m", "--model", help="HDFS path to save/load model during train/inference", default="mnist_model")
37-
parser.add_argument("-n", "--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
38-
parser.add_argument("-o", "--output", help="HDFS path to save test/inference output", default="predictions")
39-
parser.add_argument("-s", "--steps", help="maximum number of steps", type=int, default=1000)
40-
parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
41-
parser.add_argument("-X", "--mode", help="train|inference", default="train")
42-
parser.add_argument("-c", "--rdma", help="use rdma connection", default=False)
27+
parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
28+
parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
29+
parser.add_argument("--format", help="example format: (csv|csv2|pickle|tfr)", choices=["csv", "csv2", "pickle", "tfr"], default="stream")
30+
parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
31+
parser.add_argument("--model", help="HDFS path to save/load model during train/inference", default="mnist_model")
32+
parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
33+
parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
34+
parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
35+
parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
36+
parser.add_argument("--mode", help="train|inference", default="train")
37+
parser.add_argument("--rdma", help="use rdma connection", default=False)
4338
args = parser.parse_args()
44-
print("args:",args)
39+
print("args:", args)
4540

4641
print("{0} ===== Start".format(datetime.now().isoformat()))
4742

43+
4844
def parse(ln):
4945
lbl, img = ln.split('|')
5046
image = [int(x) for x in img.split(',')]
5147
label = numpy.zeros(10)
5248
label[int(lbl)] = 1.0
53-
return (image,label)
49+
return (image, label)
50+
5451

5552
stream = ssc.textFileStream(args.images)
5653
imageRDD = stream.map(lambda ln: parse(ln))
@@ -66,4 +63,3 @@ def parse(ln):
6663
cluster.shutdown(ssc)
6764

6865
print("{0} ===== Stop".format(datetime.now().isoformat()))
69-

0 commit comments

Comments (0)