
Commit 3d1ddcb

Support customized training loop and dataset using TensorFlow 1.x (#2500)

* Worker can report training data to the master if using recordio
* Add logs
* Create a TensorFlow operator
* Wrap the job command using parentheses
* Pre-commit
* Fix training loop
* Pre-commit
* Develop a TensorFlow V1 model
* Fix the bug when the host of a new worker is the same as an old worker's
* Do not call on_pod_deleted if the status changes from FAILED to DELETED
* Fix the counter when retrying
* Fix backward_passes_per_step
* Pre-commit
* Pre-commit
* Delete unused logs

1 parent b296657, commit 3d1ddcb

File tree: 3 files changed, +208 −38 lines

* elasticai_api/common/data_shard_service.py
* elasticai_api/tensorflow/controller.py
* model_zoo/mnist/mnist_train_tfv1.py

elasticai_api/common/data_shard_service.py

Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ def __init__(
         self._report_training_params()
 
     def _report_training_params(self):
-        if self._num_epochs and self._dataset_size:
+        if self._num_epochs and (self._dataset_size or self._training_data):
            self._mc.report_training_params(
                batch_size=self._batch_size,
                num_epochs=self._num_epochs,
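With this change the data shard service reports its training parameters when either a dataset size or a training data path is configured, not only when the size is known. A minimal sketch of the two configurations the guard is meant to cover, using `create_elastic_controller` from the controller change below; the literal values and the RecordIO path are placeholders, and a reachable ElasticDL master is assumed:

```python
from elasticai_api.tensorflow.controller import create_elastic_controller

# 1) The dataset size is known up front.
controller = create_elastic_controller(
    batch_size=64, num_epochs=1, dataset_size=60000
)

# 2) Only a RecordIO training data path is known (placeholder path).
controller = create_elastic_controller(
    batch_size=64, num_epochs=1, training_data="/data/mnist/train"
)
```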

elasticai_api/tensorflow/controller.py

Lines changed: 56 additions & 4 deletions

@@ -12,6 +12,7 @@
 # limitations under the License.
 
 import time
+from distutils.version import LooseVersion
 
 import tensorflow as tf
 from tensorflow.python.framework.errors_impl import UnknownError
@@ -21,6 +22,8 @@
     RETRY_ALLREDUCE_INTERVAL_SECS,
     AllReduceController,
 )
+from elasticai_api.common.data_shard_service import RecordIndexService
+from elasticai_api.common.master_client import build_master_client
 from elasticai_api.util.log_utils import default_logger as logger
 
 try:
@@ -30,6 +33,50 @@
 except ImportError:
     hvd = None
 
+_IS_TF2 = LooseVersion(tf.__version__) >= LooseVersion("2.0.0")
+
+
+def create_elastic_controller(
+    batch_size,
+    num_epochs=None,
+    dataset_size=None,
+    shuffle=False,
+    training_data=None,
+):
+    """Create an elastic AllReduce controller with data shard service.
+    Users can use the `controller.data_shard_service` to get data
+    shards like:
+    ```python
+    shard = controller.data_shard_service.fetch_shard()
+    ```
+
+    Users also can use the controller to do an elastic training.
+
+    Args:
+        batch_size: The batch size of a single worker.
+        num_epochs: The number of epochs.
+        dataset_size: The total size of dataset.
+    """
+    master_client = build_master_client()
+    record_index_service = RecordIndexService(
+        master_client=master_client,
+        batch_size=batch_size,
+        num_epochs=num_epochs,
+        dataset_size=dataset_size,
+        shuffle=shuffle,
+        training_data=training_data,
+    )
+    if _IS_TF2:
+        controller = TensorFlowV2AllReduceController(
+            master_client, record_index_service
+        )
+    else:
+        controller = TensorFlowV1AllReduceController(
+            master_client, record_index_service
+        )
+    controller.init_horovod_locally()
+    return controller
+
 
 class TensorFlowV2AllReduceController(AllReduceController):
     """The controller is responsible for elastic training of
@@ -87,13 +134,18 @@ def __init__(self, master_client, master_addr):
             master_client, master_addr
         )
         self._bcast_op = None
+        self._session = None
 
-    def broadcast(self):
+    def set_broadcast_variables(self, variables):
         if self._bcast_op is None:
-            self._variables = tf.global_variables()
+            self._variables = variables
            self._bcast_op = broadcast_variables(self._variables, root_rank=0)
-        session = tf.get_default_session()
-        session.run(self._bcast_op)
+
+    def set_session(self, session):
+        self._session = session
+
+    def broadcast(self):
+        self._session.run(self._bcast_op)
 
     def train_one_batch_with_retries(self, func, *args, **kwargs):
         allreduce_success = False
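In the TensorFlow 1.x controller, broadcasting no longer grabs tf.get_default_session() and tf.global_variables() itself; the caller now registers the variables and the session explicitly. A condensed sketch of the resulting custom training loop, distilled from the updated model_zoo example below (the stand-in graph and the RecordIO path are placeholders; a reachable ElasticDL master and a Horovod setup are assumed):

```python
import tensorflow as tf

from elasticai_api.tensorflow.controller import create_elastic_controller
from elasticai_api.tensorflow.optimizer import DistributedOptimizer


def train_one_batch(sess, run_tensors):
    return sess.run(run_tensors)


controller = create_elastic_controller(
    batch_size=64, num_epochs=1, training_data="/data/mnist/train"
)

# Trivial stand-in graph; a real job builds its model from the dataset
# backed by controller.data_shard_service.
w = tf.Variable(tf.zeros([1]))
loss = tf.reduce_mean(tf.square(w - 1.0))
optimizer = DistributedOptimizer(
    tf.train.GradientDescentOptimizer(0.1), fixed_global_batch_size=True
)
train_step = optimizer.minimize(loss)

# Wrap one training step with the controller's elastic retry logic.
elastic_train_one_batch = controller.elastic_run(train_one_batch)
controller.set_broadcast_variables(tf.global_variables())
with controller.scope():
    with tf.train.MonitoredTrainingSession() as sess:
        controller.set_session(sess)
        loss_value, _ = elastic_train_one_batch(sess, [loss, train_step])
```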

model_zoo/mnist/mnist_train_tfv1.py

Lines changed: 151 additions & 33 deletions

@@ -11,56 +11,147 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import horovod.tensorflow as hvd
+import argparse
+from contextlib import closing
+
+import recordio
 import tensorflow as tf
 
-from elasticdl.python.common.constants import Mode
+from elasticai_api.tensorflow.controller import create_elastic_controller
+from elasticai_api.tensorflow.optimizer import (
+    AdjustBackwardPassesPerStepHook,
+    DistributedOptimizer,
+)
 from elasticdl.python.common.log_utils import default_logger as logger
 
+layers = tf.layers
 
-def train(dataset, elastic_controller):
-    dataset_it = dataset.make_one_shot_iterator()
-    batch_x, batch_y = dataset_it.get_next()
-    batch_x = tf.cast(batch_x, tf.float32)
 
-    x = tf.keras.layers.Reshape((28, 28, 1))(batch_x)
-    x = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu")(x)
-    x = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu")(x)
-    x = tf.keras.layers.BatchNormalization()(x)
-    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
-    x = tf.keras.layers.Dropout(0.25)(x)
-    x = tf.keras.layers.Flatten()(x)
-    outputs = tf.keras.layers.Dense(10)(x)
-    loss = tf.reduce_mean(
-        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
-            logits=outputs, labels=tf.reshape(batch_y, [-1])
+def get_dataset_gen(data_shard_service):
+    def gen():
+        while True:
+            shard = data_shard_service.fetch_shard()
+            if not shard:
+                raise StopIteration("No data")
+            with closing(
+                recordio.Scanner(
+                    shard.name, shard.start, shard.end - shard.start,
+                )
+            ) as reader:
+                for i in range(shard.start, shard.end):
+                    record = reader.record()
+                    if record:
+                        yield record
+
+    return gen
+
+
+def create_dataset(data_shard_service):
+    gen = get_dataset_gen(data_shard_service)
+    dataset = tf.data.Dataset.from_generator(gen, tf.string)
+    return dataset
+
+
+def conv_model(feature, target, mode):
+    """2-layer convolution model."""
+    # Convert the target to a one-hot tensor of shape (batch_size, 10) and
+    # with a on-value of 1 for each one-hot vector of length 10.
+    target = tf.one_hot(tf.cast(target, tf.int32), 10, 1, 0)
+
+    # Reshape feature to 4d tensor with 2nd and 3rd dimensions being
+    # image width and height final dimension being the number of color
+    # channels.
+    feature = tf.reshape(feature, [-1, 28, 28, 1])
+
+    # First conv layer will compute 32 features for each 5x5 patch
+    with tf.variable_scope("conv_layer1"):
+        h_conv1 = layers.conv2d(
+            feature,
+            32,
+            kernel_size=[5, 5],
+            activation=tf.nn.relu,
+            padding="SAME",
+        )
+        h_pool1 = tf.nn.max_pool(
+            h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME"
+        )
+
+    # Second conv layer will compute 64 features for each 5x5 patch.
+    with tf.variable_scope("conv_layer2"):
+        h_conv2 = layers.conv2d(
+            h_pool1,
+            64,
+            kernel_size=[5, 5],
+            activation=tf.nn.relu,
+            padding="SAME",
+        )
+        h_pool2 = tf.nn.max_pool(
+            h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME"
        )
+    # reshape tensor into a batch of vectors
+    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
+
+    # Densely connected layer with 1024 neurons.
+    h_fc1 = layers.dropout(
+        layers.dense(h_pool2_flat, 1024, activation=tf.nn.relu),
+        rate=0.5,
+        training=mode == tf.estimator.ModeKeys.TRAIN,
    )
-    optimizer = tf.train.GradientDescentOptimizer(0.1)
-    optimizer = hvd.DistributedOptimizer(optimizer)
-    train_step = optimizer.minimize(loss)
 
-    with tf.Session() as sess:
-        sess.run(tf.global_variables_initializer())
+    # Compute logits (1 per class) and compute loss.
+    logits = layers.dense(h_fc1, 10, activation=None)
+    loss = tf.losses.softmax_cross_entropy(target, logits)
 
-        # Use the elastic wrapper to wrap the function to train one batch
-        elastic_train_one_batch = elastic_controller.elastic_run(
-            train_one_batch
-        )
-        for i in range(1000):
-            loss_value, _ = elastic_train_one_batch(sess, [loss, train_step])
-            logger.info("loss: {}".format(loss_value))
+    return tf.argmax(logits, 1), loss
+
+
+def train(args):
+    allreduce_controller = create_elastic_controller(
+        batch_size=args.batch_size,
+        num_epochs=args.num_epochs,
+        training_data=args.training_data,
+    )
+    dataset = create_dataset(allreduce_controller.data_shard_service)
+    dataset = feed(dataset)
+    dataset = dataset.batch(args.batch_size).prefetch(1)
+    dataset_it = dataset.make_one_shot_iterator()
+    batch_x, batch_y = dataset_it.get_next()
+    batch_x = tf.cast(batch_x, tf.float32)
+
+    batch_y = tf.reshape(batch_y, (-1,))
+    image = tf.reshape(batch_x, (-1, 784))
+    predict, loss = conv_model(image, batch_y, tf.estimator.ModeKeys.TRAIN)
+    optimizer = tf.train.GradientDescentOptimizer(0.1)
+    optimizer = DistributedOptimizer(optimizer, fixed_global_batch_size=True)
+    global_step = tf.train.get_or_create_global_step()
+    train_step = optimizer.minimize(loss, global_step=global_step)
+
+    # Use the elastic wrapper to wrap the function to train one batch
+    elastic_train_one_batch = allreduce_controller.elastic_run(train_one_batch)
+    hook = AdjustBackwardPassesPerStepHook(optimizer)
+    allreduce_controller.set_broadcast_variables(tf.global_variables())
+    with allreduce_controller.scope():
+        with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
+            allreduce_controller.set_session(sess)
+            try:
+                while True:
+                    loss_value, step, _ = elastic_train_one_batch(
+                        sess, [loss, global_step, train_step]
+                    )
+                    logger.info(
+                        "global step = {}. loss: {}".format(step, loss_value)
+                    )
+            except tf.errors.OutOfRangeError:
+                print("end!")
 
 
 def train_one_batch(sess, run_tensors):
     return sess.run(run_tensors)
 
 
-def feed(dataset, mode, _):
+def feed(dataset):
     dataset = dataset.map(_parse_data)
-
-    if mode == Mode.TRAINING:
-        dataset = dataset.shuffle(buffer_size=1024)
+    dataset = dataset.shuffle(buffer_size=1024)
     return dataset
 
 
@@ -83,3 +174,30 @@ def eval_metrics_fn():
             tf.cast(tf.reshape(labels, [-1]), tf.int32),
        )
    }
+
+
+def arg_parser():
+    parser = argparse.ArgumentParser(description="Process training parameters")
+    parser.add_argument("--batch_size", type=int, default=64, required=False)
+    parser.add_argument("--num_epochs", type=int, default=1, required=False)
+    parser.add_argument(
+        "--learning_rate", type=float, default=0.1, required=False
+    )
+    parser.add_argument(
+        "--no-cuda",
+        action="store_true",
+        default=False,
+        help="disable CUDA training",
+    )
+    parser.add_argument("--training_data", type=str, required=True)
+    parser.add_argument(
+        "--validation_data", type=str, default="", required=False
+    )
+    return parser
+
+
+if __name__ == "__main__":
+    parser = arg_parser()
+    args = parser.parse_args()
+    print(args)
+    train(args)
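The example is now a self-contained script configured through argparse, with --training_data as the only required flag. A hypothetical smoke test that drives the new entry point from Python (the RecordIO path is a placeholder; model_zoo must be importable, and a reachable ElasticDL master plus a Horovod setup are assumed, because create_elastic_controller builds a master client at startup):

```python
# Hypothetical invocation of the updated example; the data path is a
# placeholder and the surrounding ElasticDL job environment is assumed.
from model_zoo.mnist.mnist_train_tfv1 import arg_parser, train

args = arg_parser().parse_args(
    ["--training_data", "/data/mnist/train", "--batch_size", "64"]
)
train(args)
```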
