Skip to content

Commit abfd430

Browse files
committed
fix mnist dataset example
1 parent a9d661f commit abfd430

File tree

2 files changed

+30
-22
lines changed

2 files changed

+30
-22
lines changed

examples/mnist/tf/mnist_dist_dataset.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def map_fun(args, ctx):
2222
import tensorflow as tf
2323
import time
2424

25+
num_workers = args.cluster_size if args.driver_ps_nodes else args.cluster_size - args.num_ps
2526
worker_num = ctx.worker_num
2627
job_name = ctx.job_name
2728
task_index = ctx.task_index
@@ -43,10 +44,9 @@ def _parse_csv(ln):
4344
normalized_image = tf.div(image, norm)
4445
label_value = tf.string_to_number(lbl, tf.int32)
4546
label = tf.one_hot(label_value, 10)
46-
return (normalized_image, label, label_value)
47+
return (normalized_image, label)
4748

4849
def _parse_tfr(example_proto):
49-
print("example_proto: {}".format(example_proto))
5050
feature_def = {"label": tf.FixedLenFeature(10, tf.int64),
5151
"image": tf.FixedLenFeature(IMAGE_PIXELS * IMAGE_PIXELS, tf.int64)}
5252
features = tf.parse_single_example(example_proto, feature_def)
@@ -68,10 +68,17 @@ def _parse_tfr(example_proto):
6868
file_pattern = os.path.join(image_dir, 'part-*')
6969
files = tf.gfile.Glob(file_pattern)
7070

71-
parse_fn = _parse_tfr if args.format == 'tfr' else _parse_csv
72-
ds = tf.data.TextLineDataset(files).map(parse_fn).batch(args.batch_size)
71+
if args.format == 'csv2':
72+
ds = tf.data.TextLineDataset(files)
73+
parse_fn = _parse_csv
74+
else: # args.format == 'tfr'
75+
ds = tf.data.TFRecordDataset(files)
76+
parse_fn = _parse_tfr
77+
78+
ds = ds.shard(num_workers, task_index).repeat(args.epochs).shuffle(args.shuffle_size)
79+
ds = ds.map(parse_fn).batch(args.batch_size)
7380
iterator = ds.make_initializable_iterator()
74-
x, y_, y_val = iterator.get_next()
81+
x, y_ = iterator.get_next()
7582

7683
# Variables of the hidden layer
7784
hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
@@ -156,8 +163,7 @@ def _parse_tfr(example_proto):
156163
if args.mode == "train":
157164
if (step % 100 == 0):
158165
print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy)))
159-
_, summary, step, yv = sess.run([train_op, summary_op, global_step, y_val])
160-
# print("yval: {}".format(yv))
166+
_, summary, step = sess.run([train_op, summary_op, global_step])
161167
if sv.is_chief:
162168
summary_writer.add_summary(summary, step)
163169
else: # args.mode == "inference"

examples/mnist/tf/mnist_spark_dataset.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,29 @@
2121
num_ps = 1
2222

2323
parser = argparse.ArgumentParser()
24-
parser.add_argument("-b", "--batch_size", help="number of records per batch", type=int, default=100)
25-
parser.add_argument("-e", "--epochs", help="number of epochs", type=int, default=0)
26-
parser.add_argument("-f", "--format", help="example format: (csv2|tfr)", choices=["csv2", "tfr"], default="tfr")
27-
parser.add_argument("-i", "--images", help="HDFS path to MNIST images in parallelized format")
28-
parser.add_argument("-l", "--labels", help="HDFS path to MNIST labels in parallelized format")
29-
parser.add_argument("-m", "--model", help="HDFS path to save/load model during train/test", default="mnist_model")
30-
parser.add_argument("-n", "--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors)
31-
parser.add_argument("-o", "--output", help="HDFS path to save test/inference output", default="predictions")
32-
parser.add_argument("-r", "--readers", help="number of reader/enqueue threads", type=int, default=1)
33-
parser.add_argument("-s", "--steps", help="maximum number of steps", type=int, default=1000)
34-
parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
35-
parser.add_argument("-X", "--mode", help="train|inference", default="train")
36-
parser.add_argument("-c", "--rdma", help="use rdma connection", default=False)
37-
parser.add_argument("-p", "--driver_ps_nodes", help="""run tensorflow PS node on driver locally.
24+
parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
25+
parser.add_argument("--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors)
26+
parser.add_argument("--driver_ps_nodes", help="""run tensorflow PS node on driver locally.
3827
You will need to set cluster_size = num_executors + num_ps""", default=False)
28+
parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
29+
parser.add_argument("--format", help="example format: (csv2|tfr)", choices=["csv2", "tfr"], default="tfr")
30+
parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
31+
parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format")
32+
parser.add_argument("--mode", help="train|inference", default="train")
33+
parser.add_argument("--model", help="HDFS path to save/load model during train/test", default="mnist_model")
34+
parser.add_argument("--num_ps", help="number of ps nodes", default=1)
35+
parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
36+
parser.add_argument("--rdma", help="use rdma connection", default=False)
37+
parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
38+
parser.add_argument("--shuffle_size", help="size of shuffle buffer", type=int, default=1000)
39+
parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
40+
parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
3941
args = parser.parse_args()
4042
print("args:", args)
4143

4244

4345
print("{0} ===== Start".format(datetime.now().isoformat()))
44-
cluster = TFCluster.run(sc, mnist_dist_dataset.map_fun, args, args.cluster_size, num_ps, args.tensorboard,
46+
cluster = TFCluster.run(sc, mnist_dist_dataset.map_fun, args, args.cluster_size, args.num_ps, args.tensorboard,
4547
TFCluster.InputMode.TENSORFLOW, driver_ps_nodes=args.driver_ps_nodes)
4648
cluster.shutdown()
4749

0 commit comments

Comments
 (0)