@@ -59,12 +59,12 @@ def _parse_tfr(example_proto):
5959 sm_b = tf .Variable (tf .zeros ([10 ]), name = "sm_b" )
6060 tf .summary .histogram ("softmax_weights" , sm_w )
6161
62- # read from saved tf records
62+ # Read from saved tf records
6363 images = TFNode .hdfs_path (ctx , args .tfrecord_dir )
6464 tf_record_pattern = os .path .join (images , 'part-*' )
65- tfr_files = tf .gfile .Glob (tf_record_pattern )
66- ds = tf .data .TFRecordDataset (tfr_files )
65+ ds = tf .data .Dataset .list_files (tf_record_pattern )
6766 ds = ds .shard (num_workers , task_index ).repeat (args .epochs ).shuffle (args .shuffle_size )
67+ ds = ds .interleave (tf .data .TFRecordDataset , cycle_length = args .readers , block_length = 1 )
6868 ds = ds .map (_parse_tfr ).batch (args .batch_size )
6969 iterator = ds .make_initializable_iterator ()
7070 x , y_ = iterator .get_next ()
@@ -122,7 +122,6 @@ def _parse_tfr(example_proto):
122122 # See `tf.train.SyncReplicasOptimizer` for additional details on how to
123123 # perform *synchronous* training.
124124
125- # using QueueRunners/Readers
126125 if (step % 100 == 0 ):
127126 print ("{0} step: {1} accuracy: {2}" .format (datetime .now ().isoformat (), step , sess .run (accuracy )))
128127 _ , summary , step = sess .run ([train_op , summary_op , global_step ])
0 commit comments