
Commit 3856c05

Merge pull request #307 from yileic/rm_shard
shard() on records is slow when the data is large, so divide the input files among workers before reading
2 parents: bc63a3f + ed49d7d
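
The core change, sketched below with assumed TF 1.x APIs (the HDFS path and worker numbers are hypothetical; in the real scripts they come from the TFoS context and command-line args): previously each worker built a dataset over all records and then called shard(), forcing every worker to read the full input just to discard most of it. Now the small list of file names is sharded first, so each worker only ever opens its own files.

import tensorflow as tf

file_pattern = "hdfs:///mnist/tfr/train/part-*"  # hypothetical path
num_workers, task_index = 4, 0                   # provided by the TFoS ctx in the examples

# Before: every worker reads all records, then keeps only every 4th one.
files = tf.gfile.Glob(file_pattern)
ds_old = tf.data.TFRecordDataset(files).shard(num_workers, task_index)

# After: shard the file names; other workers' files are never opened or read.
ds_new = tf.data.Dataset.list_files(file_pattern)
ds_new = ds_new.shard(num_workers, task_index)
ds_new = ds_new.interleave(tf.data.TFRecordDataset, cycle_length=10, block_length=1)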

5 files changed: +13 additions, -16 deletions

examples/mnist/tf/mnist_dist.py

Lines changed: 4 additions & 4 deletions
@@ -41,14 +41,14 @@ def map_fun(args, ctx):
   def read_csv_examples(image_dir, label_dir, batch_size=100, num_epochs=None, task_index=None, num_workers=None):
     print_log(worker_num, "num_epochs: {0}".format(num_epochs))
     # Setup queue of csv image filenames
-    tf_record_pattern = os.path.join(image_dir, 'part-*')
-    images = tf.gfile.Glob(tf_record_pattern)
+    csv_file_pattern = os.path.join(image_dir, 'part-*')
+    images = tf.gfile.Glob(csv_file_pattern)
     print_log(worker_num, "images: {0}".format(images))
     image_queue = tf.train.string_input_producer(images, shuffle=False, capacity=1000, num_epochs=num_epochs, name="image_queue")

     # Setup queue of csv label filenames
-    tf_record_pattern = os.path.join(label_dir, 'part-*')
-    labels = tf.gfile.Glob(tf_record_pattern)
+    csv_file_pattern = os.path.join(label_dir, 'part-*')
+    labels = tf.gfile.Glob(csv_file_pattern)
     print_log(worker_num, "labels: {0}".format(labels))
     label_queue = tf.train.string_input_producer(labels, shuffle=False, capacity=1000, num_epochs=num_epochs, name="label_queue")

examples/mnist/tf/mnist_dist_dataset.py

Lines changed: 4 additions & 6 deletions
@@ -66,16 +66,15 @@ def _parse_tfr(example_proto):
   # Dataset for input data
   image_dir = TFNode.hdfs_path(ctx, args.images_labels)
   file_pattern = os.path.join(image_dir, 'part-*')
-  files = tf.gfile.Glob(file_pattern)

+  ds = tf.data.Dataset.list_files(file_pattern)
+  ds = ds.shard(num_workers, task_index).repeat(args.epochs).shuffle(args.shuffle_size)
   if args.format == 'csv2':
-    ds = tf.data.TextLineDataset(files)
+    ds = ds.interleave(tf.data.TextLineDataset, cycle_length=args.readers, block_length=1)
     parse_fn = _parse_csv
   else:  # args.format == 'tfr'
-    ds = tf.data.TFRecordDataset(files)
+    ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=args.readers, block_length=1)
     parse_fn = _parse_tfr
-
-  ds = ds.shard(num_workers, task_index).repeat(args.epochs).shuffle(args.shuffle_size)
   ds = ds.map(parse_fn).batch(args.batch_size)
   iterator = ds.make_initializable_iterator()
   x, y_ = iterator.get_next()

@@ -159,7 +158,6 @@ def _parse_tfr(example_proto):
   # See `tf.train.SyncReplicasOptimizer` for additional details on how to
   # perform *synchronous* training.

-  # using QueueRunners/Readers
   if args.mode == "train":
     if (step % 100 == 0):
       print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy)))

examples/mnist/tf/mnist_dist_pipeline.py

Lines changed: 3 additions & 4 deletions
@@ -59,12 +59,12 @@ def _parse_tfr(example_proto):
   sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
   tf.summary.histogram("softmax_weights", sm_w)

-  # read from saved tf records
+  # Read from saved tf records
   images = TFNode.hdfs_path(ctx, args.tfrecord_dir)
   tf_record_pattern = os.path.join(images, 'part-*')
-  tfr_files = tf.gfile.Glob(tf_record_pattern)
-  ds = tf.data.TFRecordDataset(tfr_files)
+  ds = tf.data.Dataset.list_files(tf_record_pattern)
   ds = ds.shard(num_workers, task_index).repeat(args.epochs).shuffle(args.shuffle_size)
+  ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=args.readers, block_length=1)
   ds = ds.map(_parse_tfr).batch(args.batch_size)
   iterator = ds.make_initializable_iterator()
   x, y_ = iterator.get_next()

@@ -122,7 +122,6 @@ def _parse_tfr(example_proto):
   # See `tf.train.SyncReplicasOptimizer` for additional details on how to
   # perform *synchronous* training.

-  # using QueueRunners/Readers
   if (step % 100 == 0):
     print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy)))
   _, summary, step = sess.run([train_op, summary_op, global_step])

examples/mnist/tf/mnist_spark_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@
 parser.add_argument("--num_ps", help="number of ps nodes", default=1)
 parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
 parser.add_argument("--rdma", help="use rdma connection", default=False)
-parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
+parser.add_argument("--readers", help="number of reader/enqueue threads per worker", type=int, default=10)
 parser.add_argument("--shuffle_size", help="size of shuffle buffer", type=int, default=1000)
 parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
 parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")

examples/mnist/tf/mnist_spark_pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@
 parser.add_argument("-p", "--driver_ps_nodes", help="""run tensorflow PS node on driver locally.
     You will need to set cluster_size = num_executors + num_ps""", default=False)
 parser.add_argument("--protocol", help="Tensorflow network protocol (grpc|rdma)", default="grpc")
-parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
+parser.add_argument("--readers", help="number of reader/enqueue threads per worker", type=int, default=10)
 parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
 parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
 parser.add_argument("--shuffle_size", help="size of shuffle buffer", type=int, default=1000)
