1 | | -# Copyright 2018 Yahoo Inc. |
| 1 | +# Copyright 2017 Yahoo Inc. |
2 | 2 | # Licensed under the terms of the Apache 2.0 license. |
3 | 3 | # Please see LICENSE file in the project root for terms. |
4 | 4 |
5 | 5 | # Distributed MNIST on grid based on TensorFlow MNIST example |
6 | 6 |
7 | 7 | from __future__ import absolute_import |
8 | 8 | from __future__ import division |
9 | | -from __future__ import nested_scopes |
10 | 9 | from __future__ import print_function |
11 | 10 |
12 | | -from datetime import datetime |
13 | | -import tensorflow as tf |
14 | | -from tensorflowonspark import TFNode |
15 | | - |
16 | 11 |
17 | 12 | def print_log(worker_num, arg): |
18 | 13 | print("{0}: {1}".format(worker_num, arg)) |
19 | 14 |
20 | 15 |
21 | | -class ExportHook(tf.train.SessionRunHook): |
22 | | - def __init__(self, export_dir, input_tensor, output_tensor): |
23 | | - self.export_dir = export_dir |
24 | | - self.input_tensor = input_tensor |
25 | | - self.output_tensor = output_tensor |
26 | | - |
27 | | - def end(self, session): |
28 | | - print("{} ======= Exporting to: {}".format(datetime.now().isoformat(), self.export_dir)) |
29 | | - signatures = { |
30 | | - tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: { |
31 | | - 'inputs': {'image': self.input_tensor}, |
32 | | - 'outputs': {'prediction': self.output_tensor}, |
33 | | - 'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME |
34 | | - } |
35 | | - } |
36 | | - TFNode.export_saved_model(session, |
37 | | - self.export_dir, |
38 | | - tf.saved_model.tag_constants.SERVING, |
39 | | - signatures) |
40 | | - print("{} ======= Done exporting".format(datetime.now().isoformat())) |
41 | | - |
42 | | - |
43 | 16 | def map_fun(args, ctx): |
| 17 | + from datetime import datetime |
44 | 18 | import math |
45 | 19 | import numpy |
| 20 | + import tensorflow as tf |
46 | 21 | import time |
47 | 22 |
48 | 23 | worker_num = ctx.worker_num |
49 | 24 | job_name = ctx.job_name |
50 | 25 | task_index = ctx.task_index |
51 | 26 |
52 | | - # Delay PS nodes a bit, since workers seem to reserve GPUs more quickly/reliably (w/o conflict) |
53 | | - if job_name == "ps": |
54 | | - time.sleep((worker_num + 1) * 5) |
55 | | - |
56 | 27 | # Parameters |
57 | 28 | IMAGE_PIXELS = 28 |
58 | 29 | hidden_units = 128 |
59 | | - batch_size = args.batch_size |
60 | 30 |
61 | 31 | # Get TF cluster and server instances |
62 | 32 | cluster, server = ctx.start_cluster_server(1, args.rdma) |
63 | 33 |
64 | | - def feed_dict(batch): |
65 | | - # Convert from [(images, labels)] to two numpy arrays of the proper type |
66 | | - images = [] |
67 | | - labels = [] |
68 | | - for item in batch: |
69 | | - images.append(item[0]) |
70 | | - labels.append(item[1]) |
71 | | - xs = numpy.array(images) |
72 | | - xs = xs.astype(numpy.float32) |
73 | | - xs = xs / 255.0 |
74 | | - ys = numpy.array(labels) |
75 | | - ys = ys.astype(numpy.uint8) |
76 | | - return (xs, ys) |
| 34 | + # Create generator for Spark data feed |
| 35 | + tf_feed = ctx.get_data_feed(args.mode == 'train') |
| 36 | + |
| 37 | + def rdd_generator(): |
| 38 | + while not tf_feed.should_stop(): |
| 39 | + batch = tf_feed.next_batch(1) |
| 40 | + if len(batch) == 0: |
| 41 | + return |
| 42 | + row = batch[0] |
| 43 | + image = numpy.array(row[0]).astype(numpy.float32) / 255.0 |
| 44 | +      label = numpy.array(row[1]).astype(numpy.float32)
| 45 | + yield (image, label) |
77 | 46 |
78 | 47 | if job_name == "ps": |
79 | 48 | server.join() |
80 | 49 | elif job_name == "worker": |
81 | | - |
82 | 50 | # Assigns ops to the local worker by default. |
83 | 51 | with tf.device(tf.train.replica_device_setter( |
84 | 52 | worker_device="/job:worker/task:%d" % task_index, |
85 | 53 | cluster=cluster)): |
86 | 54 |
87 | | - # Placeholders or QueueRunner/Readers for input data |
88 | | - with tf.name_scope('inputs'): |
89 | | - x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x") |
90 | | - y_ = tf.placeholder(tf.float32, [None, 10], name="y_") |
91 | | - |
92 | | - x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1]) |
93 | | - tf.summary.image("x_img", x_img) |
94 | | - |
95 | | - with tf.name_scope('layer'): |
96 | | - # Variables of the hidden layer |
97 | | - with tf.name_scope('hidden_layer'): |
98 | | - hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units], stddev=1.0 / IMAGE_PIXELS), name="hid_w") |
99 | | - hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b") |
100 | | - tf.summary.histogram("hidden_weights", hid_w) |
101 | | - hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b) |
102 | | - hid = tf.nn.relu(hid_lin) |
103 | | - |
104 | | - # Variables of the softmax layer |
105 | | - with tf.name_scope('softmax_layer'): |
106 | | - sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10], stddev=1.0 / math.sqrt(hidden_units)), name="sm_w") |
107 | | - sm_b = tf.Variable(tf.zeros([10]), name="sm_b") |
108 | | - tf.summary.histogram("softmax_weights", sm_w) |
109 | | - y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b)) |
| 55 | + # Dataset for input data |
| 56 | + ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([IMAGE_PIXELS * IMAGE_PIXELS]), tf.TensorShape([10]))).batch(args.batch_size) |
| 57 | + iterator = ds.make_one_shot_iterator() |
| 58 | + x, y_ = iterator.get_next() |
110 | 59 |
111 | | - global_step = tf.train.get_or_create_global_step() |
| 60 | + # Variables of the hidden layer |
| 61 | + hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units], |
| 62 | + stddev=1.0 / IMAGE_PIXELS), name="hid_w") |
| 63 | + hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b") |
| 64 | + tf.summary.histogram("hidden_weights", hid_w) |
| 65 | + |
| 66 | + # Variables of the softmax layer |
| 67 | + sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10], |
| 68 | + stddev=1.0 / math.sqrt(hidden_units)), name="sm_w") |
| 69 | + sm_b = tf.Variable(tf.zeros([10]), name="sm_b") |
| 70 | + tf.summary.histogram("softmax_weights", sm_w) |
| 71 | + |
| 72 | + x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1]) |
| 73 | + tf.summary.image("x_img", x_img) |
| 74 | + |
| 75 | + hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b) |
| 76 | + hid = tf.nn.relu(hid_lin) |
112 | 77 |
113 | | - with tf.name_scope('loss'): |
114 | | - loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) |
115 | | - tf.summary.scalar("loss", loss) |
| 78 | + y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b)) |
116 | 79 |
117 | | - with tf.name_scope('train'): |
118 | | - train_op = tf.train.AdagradOptimizer(0.01).minimize(loss, global_step=global_step) |
| 80 | + global_step = tf.train.get_or_create_global_step() |
| 81 | + |
| 82 | + loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) |
| 83 | + tf.summary.scalar("loss", loss) |
| 84 | + train_op = tf.train.AdagradOptimizer(0.01).minimize( |
| 85 | + loss, global_step=global_step) |
119 | 86 |
120 | 87 | # Test trained model |
121 | 88 | label = tf.argmax(y_, 1, name="label") |
122 | 89 | prediction = tf.argmax(y, 1, name="prediction") |
123 | 90 | correct_prediction = tf.equal(prediction, label) |
124 | | - |
125 | 91 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy") |
126 | 92 | tf.summary.scalar("acc", accuracy) |
127 | 93 |
| 94 | + saver = tf.train.Saver() |
128 | 95 | summary_op = tf.summary.merge_all() |
| 96 | + init_op = tf.global_variables_initializer() |
129 | 97 |
| 98 | +    # The MonitoredTrainingSession below takes care of session initialization, restoring from a checkpoint, and saving model state into HDFS
130 | 99 | logdir = ctx.absolute_path(args.model) |
131 | 100 | print("tensorflow model path: {0}".format(logdir)) |
| 101 | + summary_writer = tf.summary.FileWriter("tensorboard_%d" % worker_num, graph=tf.get_default_graph()) |
132 | 102 |
133 | | - if job_name == "worker" and task_index == 0: |
134 | | - summary_writer = tf.summary.FileWriter(logdir, graph=tf.get_default_graph()) |
135 | | - |
136 | | - # The MonitoredTrainingSession takes care of session initialization, restoring from |
137 | | - # a checkpoint, and closing when done or an error occurs |
138 | 103 | with tf.train.MonitoredTrainingSession(master=server.target, |
139 | | - is_chief=(task_index == 0), |
140 | | - checkpoint_dir=logdir, |
141 | | - save_checkpoint_secs=10, |
142 | | - hooks=[tf.train.StopAtStepHook(last_step=args.steps)], |
143 | | - chief_only_hooks=[ExportHook(ctx.absolute_path(args.export_dir), x, prediction)]) as mon_sess: |
| 104 | + is_chief=(task_index == 0), |
| 105 | + scaffold=tf.train.Scaffold(init_op=init_op, summary_op=summary_op, saver=saver), |
| 106 | + checkpoint_dir=logdir, |
| 107 | + hooks=[tf.train.StopAtStepHook(last_step=args.steps)]) as sess: |
| 108 | + print("{} session ready".format(datetime.now().isoformat())) |
| 109 | + |
| 110 | + # Loop until the session shuts down or feed has no more data |
144 | 111 | step = 0 |
145 | | - tf_feed = ctx.get_data_feed(args.mode == "train") |
146 | | - while not mon_sess.should_stop() and not tf_feed.should_stop(): |
147 | | - # Run a training step asynchronously |
| 112 | + while not sess.should_stop() and not tf_feed.should_stop(): |
| 113 | + # Run a training step asynchronously. |
148 | 114 | # See `tf.train.SyncReplicasOptimizer` for additional details on how to |
149 | 115 | # perform *synchronous* training. |
150 | 116 |
151 | | - # using feed_dict |
152 | | - batch_xs, batch_ys = feed_dict(tf_feed.next_batch(batch_size)) |
153 | | - feed = {x: batch_xs, y_: batch_ys} |
154 | | - |
155 | | - if len(batch_xs) > 0: |
156 | | - if args.mode == "train": |
157 | | - _, summary, step = mon_sess.run([train_op, summary_op, global_step], feed_dict=feed) |
158 | | - # print accuracy and save model checkpoint to HDFS every 100 steps |
159 | | - if (step % 100 == 0): |
160 | | - print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, mon_sess.run(accuracy, {x: batch_xs, y_: batch_ys}))) |
161 | | - |
162 | | - if task_index == 0: |
163 | | - summary_writer.add_summary(summary, step) |
164 | | - else: # args.mode == "inference" |
165 | | - labels, preds, acc = mon_sess.run([label, prediction, accuracy], feed_dict=feed) |
166 | | - |
167 | | - results = ["{0} Label: {1}, Prediction: {2}".format(datetime.now().isoformat(), l, p) for l, p in zip(labels, preds)] |
168 | | - tf_feed.batch_results(results) |
169 | | - print("results: {0}, acc: {1}".format(results, acc)) |
170 | | - |
171 | | - if mon_sess.should_stop() or step >= args.steps: |
172 | | - tf_feed.terminate() |
173 | | - |
174 | | - # Ask for all the services to stop. |
175 | | - print("{0} stopping MonitoredTrainingSession".format(datetime.now().isoformat())) |
176 | | - |
177 | | - if job_name == "worker" and task_index == 0: |
178 | | - summary_writer.close() |
| 117 | + if args.mode == "train": |
| 118 | + _, summary, step = sess.run([train_op, summary_op, global_step]) |
| 119 | + if (step % 100 == 0): |
| 120 | + print("{} step: {} accuracy: {}".format(datetime.now().isoformat(), step, sess.run(accuracy))) |
| 121 | + if task_index == 0: |
| 122 | + summary_writer.add_summary(summary, step) |
| 123 | + else: # args.mode == "inference" |
| 124 | + labels, preds, acc = sess.run([label, prediction, accuracy]) |
| 125 | + results = ["{} Label: {}, Prediction: {}".format(datetime.now().isoformat(), l, p) for l, p in zip(labels, preds)] |
| 126 | + tf_feed.batch_results(results) |
| 127 | + print("acc: {}".format(acc)) |
| 128 | + |
| 129 | + print("{} stopping MonitoredTrainingSession".format(datetime.now().isoformat())) |
| 130 | + |
| 131 | + # WORKAROUND FOR https://github.com/tensorflow/tensorflow/issues/21745 |
| 132 | + # wait for all other nodes to complete (via done files) |
| 133 | + done_dir = "{}/{}/done".format(ctx.absolute_path(args.model), args.mode) |
| 134 | + print("Writing done file to: {}".format(done_dir)) |
| 135 | + tf.gfile.MakeDirs(done_dir) |
| 136 | + with tf.gfile.GFile("{}/{}".format(done_dir, ctx.task_index), 'w') as done_file: |
| 137 | + done_file.write("done") |
| 138 | + |
| 139 | + for i in range(60): |
| 140 | + if len(tf.gfile.ListDirectory(done_dir)) < len(ctx.cluster_spec['worker']): |
| 141 | + print("{} Waiting for other nodes {}".format(datetime.now().isoformat(), i)) |
| 142 | + time.sleep(1) |
| 143 | + else: |
| 144 | + print("{} All nodes done".format(datetime.now().isoformat())) |
| 145 | + break |
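
For context, this `map_fun` is intended to be launched from a Spark driver with TensorFlowOnSpark's `TFCluster` in `InputMode.SPARK`, which supplies the RDD consumed by `ctx.get_data_feed` / `rdd_generator` above. Below is a minimal driver-side sketch, not part of this commit: the argument names, CSV input format, and defaults are illustrative assumptions, while `TFCluster.run`, `cluster.train`, `cluster.inference`, `cluster.shutdown`, and `InputMode.SPARK` are the actual TensorFlowOnSpark entry points.

# Hypothetical driver script; argument names and CSV parsing are assumptions.
import argparse
from pyspark.context import SparkContext
from pyspark.conf import SparkConf
from tensorflowonspark import TFCluster
import mnist_dist  # assumed name of the module containing map_fun above

sc = SparkContext(conf=SparkConf().setAppName("mnist_spark"))

parser = argparse.ArgumentParser()
parser.add_argument("--images", help="HDFS path to MNIST images as CSV")
parser.add_argument("--labels", help="HDFS path to MNIST one-hot labels as CSV")
parser.add_argument("--model", default="mnist_model", help="checkpoint/summary directory")
parser.add_argument("--output", default="predictions", help="output path for inference results")
parser.add_argument("--mode", default="train", help="train|inference")
parser.add_argument("--batch_size", type=int, default=100)
parser.add_argument("--steps", type=int, default=1000)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--cluster_size", type=int, default=3, help="total executors (workers + ps)")
parser.add_argument("--rdma", action="store_true")
args = parser.parse_args()

# RDD of (784-float image vector, 10-float one-hot label) pairs, the shape rdd_generator expects
images = sc.textFile(args.images).map(lambda ln: [float(x) for x in ln.split(',')])
labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
images_labels = images.zip(labels)

cluster = TFCluster.run(sc, mnist_dist.map_fun, args, args.cluster_size, num_ps=1,
                        tensorboard=False, input_mode=TFCluster.InputMode.SPARK)
if args.mode == "train":
  cluster.train(images_labels, args.epochs)
else:
  # inference() returns an RDD of the strings passed to tf_feed.batch_results()
  predictions = cluster.inference(images_labels)
  predictions.saveAsTextFile(args.output)
cluster.shutdown()

In `InputMode.SPARK`, `cluster.train()` and `cluster.inference()` push RDD partitions into per-executor queues, which is what `tf_feed.next_batch()` reads inside `rdd_generator`.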