# See the License for the specific language governing permissions and
# limitations under the License.

-import horovod.tensorflow as hvd
+import argparse
+from contextlib import closing
+
+import recordio
import tensorflow as tf

-from elasticdl.python.common.constants import Mode
+from elasticai_api.tensorflow.controller import create_elastic_controller
+from elasticai_api.tensorflow.optimizer import (
+    AdjustBackwardPassesPerStepHook,
+    DistributedOptimizer,
+)
from elasticdl.python.common.log_utils import default_logger as logger

+layers = tf.layers

-def train(dataset, elastic_controller):
-    dataset_it = dataset.make_one_shot_iterator()
-    batch_x, batch_y = dataset_it.get_next()
-    batch_x = tf.cast(batch_x, tf.float32)

-    x = tf.keras.layers.Reshape((28, 28, 1))(batch_x)
-    x = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu")(x)
-    x = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu")(x)
-    x = tf.keras.layers.BatchNormalization()(x)
-    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
-    x = tf.keras.layers.Dropout(0.25)(x)
-    x = tf.keras.layers.Flatten()(x)
-    outputs = tf.keras.layers.Dense(10)(x)
-    loss = tf.reduce_mean(
-        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
-            logits=outputs, labels=tf.reshape(batch_y, [-1])
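+# Build a generator that repeatedly fetches a shard from the data shard
+# service and yields the raw records read from that shard's RecordIO range.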
+def get_dataset_gen(data_shard_service):
+    def gen():
+        while True:
+            shard = data_shard_service.fetch_shard()
+            if not shard:
+                # No shard left: end the generator. Raising StopIteration
+                # inside a generator is an error under PEP 479.
+                break
+            with closing(
+                recordio.Scanner(
+                    shard.name, shard.start, shard.end - shard.start,
+                )
+            ) as reader:
+                for _ in range(shard.start, shard.end):
+                    record = reader.record()
+                    if record:
+                        yield record
+
+    return gen
+
+
+def create_dataset(data_shard_service):
+    gen = get_dataset_gen(data_shard_service)
+    dataset = tf.data.Dataset.from_generator(gen, tf.string)
+    return dataset
+
+
+def conv_model(feature, target, mode):
+    """2-layer convolution model."""
+    # Convert the target to a one-hot tensor of shape (batch_size, 10),
+    # with an on-value of 1 for each one-hot vector of length 10.
+    target = tf.one_hot(tf.cast(target, tf.int32), 10, 1, 0)
+
+    # Reshape feature to a 4d tensor with the 2nd and 3rd dimensions being
+    # image width and height, and the final dimension being the number of
+    # color channels.
+    feature = tf.reshape(feature, [-1, 28, 28, 1])
+
+    # First conv layer will compute 32 features for each 5x5 patch.
+    with tf.variable_scope("conv_layer1"):
+        h_conv1 = layers.conv2d(
+            feature,
+            32,
+            kernel_size=[5, 5],
+            activation=tf.nn.relu,
+            padding="SAME",
+        )
+        h_pool1 = tf.nn.max_pool(
+            h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME"
+        )
+
+    # Second conv layer will compute 64 features for each 5x5 patch.
+    with tf.variable_scope("conv_layer2"):
+        h_conv2 = layers.conv2d(
+            h_pool1,
+            64,
+            kernel_size=[5, 5],
+            activation=tf.nn.relu,
+            padding="SAME",
+        )
+        h_pool2 = tf.nn.max_pool(
+            h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME"
        )
+        # Reshape the tensor into a batch of vectors.
+        h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
+
+    # Densely connected layer with 1024 neurons.
+    h_fc1 = layers.dropout(
+        layers.dense(h_pool2_flat, 1024, activation=tf.nn.relu),
+        rate=0.5,
+        training=mode == tf.estimator.ModeKeys.TRAIN,
    )
-    optimizer = tf.train.GradientDescentOptimizer(0.1)
-    optimizer = hvd.DistributedOptimizer(optimizer)
-    train_step = optimizer.minimize(loss)

-    with tf.Session() as sess:
-        sess.run(tf.global_variables_initializer())
+    # Compute logits (1 per class) and compute loss.
+    logits = layers.dense(h_fc1, 10, activation=None)
+    loss = tf.losses.softmax_cross_entropy(target, logits)

-        # Use the elastic wrapper to wrap the function to train one batch
-        elastic_train_one_batch = elastic_controller.elastic_run(
-            train_one_batch
-        )
-        for i in range(1000):
-            loss_value, _ = elastic_train_one_batch(sess, [loss, train_step])
-            logger.info("loss: {}".format(loss_value))
+    return tf.argmax(logits, 1), loss
+
+
+def train(args):
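+    # The elastic allreduce controller provides dynamic data sharding through
+    # its data shard service and lets the training loop keep running while
+    # workers join or leave.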
+    allreduce_controller = create_elastic_controller(
+        batch_size=args.batch_size,
+        num_epochs=args.num_epochs,
+        training_data=args.training_data,
+    )
+    dataset = create_dataset(allreduce_controller.data_shard_service)
+    dataset = feed(dataset)
+    dataset = dataset.batch(args.batch_size).prefetch(1)
+    dataset_it = dataset.make_one_shot_iterator()
+    batch_x, batch_y = dataset_it.get_next()
+    batch_x = tf.cast(batch_x, tf.float32)
+
+    batch_y = tf.reshape(batch_y, (-1,))
+    image = tf.reshape(batch_x, (-1, 784))
+    predict, loss = conv_model(image, batch_y, tf.estimator.ModeKeys.TRAIN)
+    optimizer = tf.train.GradientDescentOptimizer(0.1)
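+    # DistributedOptimizer allreduces gradients across workers; with
+    # fixed_global_batch_size=True the backward passes accumulated per step
+    # are adjusted (see the hook below) so the global batch size stays fixed
+    # as the number of workers changes.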
+    optimizer = DistributedOptimizer(optimizer, fixed_global_batch_size=True)
+    global_step = tf.train.get_or_create_global_step()
+    train_step = optimizer.minimize(loss, global_step=global_step)
+
+    # Wrap the function that trains one batch with the elastic wrapper.
+    elastic_train_one_batch = allreduce_controller.elastic_run(train_one_batch)
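+    # The hook adjusts the optimizer's backward passes per step as the worker
+    # count changes; the broadcast variables are synchronized to workers that
+    # (re)join the job.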
+    hook = AdjustBackwardPassesPerStepHook(optimizer)
+    allreduce_controller.set_broadcast_variables(tf.global_variables())
+    with allreduce_controller.scope():
+        with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
+            allreduce_controller.set_session(sess)
+            try:
+                while True:
+                    loss_value, step, _ = elastic_train_one_batch(
+                        sess, [loss, global_step, train_step]
+                    )
+                    logger.info(
+                        "global step = {}. loss: {}".format(step, loss_value)
+                    )
+            except tf.errors.OutOfRangeError:
+                print("end!")


def train_one_batch(sess, run_tensors):
    return sess.run(run_tensors)


-def feed(dataset, mode, _):
+def feed(dataset):
    dataset = dataset.map(_parse_data)
-
-    if mode == Mode.TRAINING:
-        dataset = dataset.shuffle(buffer_size=1024)
+    dataset = dataset.shuffle(buffer_size=1024)
    return dataset


@@ -83,3 +174,30 @@ def eval_metrics_fn():
            tf.cast(tf.reshape(labels, [-1]), tf.int32),
        )
    }
+
+
+def arg_parser():
+    parser = argparse.ArgumentParser(description="Process training parameters")
+    parser.add_argument("--batch_size", type=int, default=64, required=False)
+    parser.add_argument("--num_epochs", type=int, default=1, required=False)
+    parser.add_argument(
+        "--learning_rate", type=float, default=0.1, required=False
+    )
+    parser.add_argument(
+        "--no-cuda",
+        action="store_true",
+        default=False,
+        help="disable CUDA training",
+    )
+    parser.add_argument("--training_data", type=str, required=True)
+    parser.add_argument(
+        "--validation_data", type=str, default="", required=False
+    )
+    return parser
+
+
+if __name__ == "__main__":
+    parser = arg_parser()
+    args = parser.parse_args()
+    print(args)
+    train(args)
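+
+# Example invocation (the script name and data path below are illustrative):
+#   python mnist_train_tfv1.py --training_data=/data/mnist/train \
+#       --batch_size=64 --num_epochs=1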