Skip to content

Commit 94029ca

Browse files
authored
Remove the argument custom_training_loop (#2509)
* Remove the argument custom_training_loop * Pre-commit * Remove codes * Fix unittests * pre-commit * Fix the docstring
1 parent 8766419 commit 94029ca

File tree

10 files changed

+8
-162
lines changed

10 files changed

+8
-162
lines changed

elasticdl/python/common/model_utils.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -191,40 +191,6 @@ def get_model_spec(
191191
)
192192

193193

194-
def get_training_func_spec(
195-
model_zoo, model_def, feed, custom_data_reader,
196-
):
197-
"""Get the model spec items in a tuple.
198-
199-
Args:
200-
model_zoo: String, the folder name of model files.
201-
model_def: The import path to the model definition function/class in
202-
the "model zoo".
203-
feed: the function name in the model definition file to convert the
204-
input data.
205-
custom_data_reader: the function name in the model definition file
206-
to read data from the storage.
207-
208-
The model spec tuple contains the following items in order:
209-
210-
* The `training_func` of training loop.
211-
* The `feed`,
212-
* The `custom_data_reader`
213-
"""
214-
model_def_module_file = get_module_file_path(model_zoo, model_def)
215-
default_module = load_module(model_def_module_file).__dict__
216-
training_func_name = model_def.split(".")[-1]
217-
training_func = _get_spec_value(
218-
training_func_name, model_zoo, default_module, required=True
219-
)
220-
221-
return (
222-
training_func,
223-
_get_spec_value(feed, model_zoo, default_module, required=False),
224-
_get_spec_value(custom_data_reader, model_zoo, default_module),
225-
)
226-
227-
228194
def find_layer(model, layer_class):
229195
"""
230196
Find all layers in model that are instances of layer_class

elasticdl/python/master/elasticdl_job_service.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,7 @@ def __init__(self, args, task_manager, rendezvous_server=None):
7171
get_module_file_path(args.model_zoo, args.model_def)
7272
).__dict__
7373

74-
self._optimizer = (
75-
None
76-
if args.custom_training_loop
77-
else model_module[args.optimizer]()
78-
)
74+
self._optimizer = model_module[args.optimizer]()
7975

8076
# TODO: Remove task manage and rendezvous server after
8177
# refactoring pod manager.

elasticdl/python/master/task_manager.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,7 @@ def __init__(
155155
args.validation_data, args.data_reader_params
156156
)
157157
self._set_completed_steps_by_checkpoint(args.checkpoint_dir_for_init)
158-
if not args.custom_training_loop:
159-
self._add_deferred_callback_create_train_end_task()
158+
self._add_deferred_callback_create_train_end_task()
160159

161160
self._max_task_completed_times = {
162161
elasticai_api_pb2.EVALUATION: 0,

elasticdl/python/tests/elasticdl_job_service_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def test_create_master_for_allreduce(self):
6262
temp_dir=temp_dir_name,
6363
)
6464
self.arguments["training_data"] = temp_dir_name
65-
self.arguments["custom_training_loop"] = "true"
6665
args = self._get_args()
6766
args = parse_master_args(args)
6867
master = ElasticdlJobService(args, TaskManager(args))
@@ -72,8 +71,7 @@ def test_create_master_without_eval(self):
7271
self.arguments[
7372
"distribution_strategy"
7473
] = DistributionStrategy.ALLREDUCE
75-
self.arguments["custom_training_loop"] = "true"
76-
self.arguments["model_def"] = "mnist.mnist_train_tfv2.train"
74+
self.arguments["model_def"] = "mnist.mnist_functional_api.custom_model"
7775
with tempfile.TemporaryDirectory() as temp_dir_name:
7876
create_recordio_file(
7977
self._num_records,

elasticdl/python/tests/model_utils_test.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
get_model_spec,
2323
get_module_file_path,
2424
get_optimizer_info,
25-
get_training_func_spec,
2625
)
2726

2827
_model_zoo_path = os.path.dirname(os.path.realpath(__file__))
@@ -76,20 +75,6 @@ def test_get_model_spec(self):
7675
callbacks="callbacks",
7776
)
7877

79-
def test_training_func_spec(self):
80-
model_zoo_path = os.path.join(
81-
os.path.dirname(os.path.realpath(__file__)), "../../../model_zoo"
82-
)
83-
(train_spec, feed, data_reader,) = get_training_func_spec(
84-
model_zoo=model_zoo_path,
85-
model_def="mnist.mnist_train_tfv2.train",
86-
feed="feed",
87-
custom_data_reader="custom_data_reader",
88-
)
89-
self.assertIsNotNone(train_spec)
90-
self.assertIsNotNone(feed)
91-
self.assertIsNone(data_reader)
92-
9378
def test_get_module_file_path(self):
9479
self.assertEqual(
9580
get_module_file_path(_model_zoo_path, "test_module.custom_model"),

elasticdl/python/tests/test_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ def __init__(
107107
model_def="",
108108
custom_data_reader="custom_data_reader",
109109
checkpoint_dir_for_init="",
110-
custom_training_loop=False,
111110
task_fault_tolerance=True,
112111
relaunch_timeout_worker=True,
113112
):
@@ -122,7 +121,6 @@ def __init__(
122121
self.model_def = model_def
123122
self.custom_data_reader = custom_data_reader
124123
self.checkpoint_dir_for_init = checkpoint_dir_for_init
125-
self.custom_training_loop = custom_training_loop
126124
self.task_fault_tolerance = task_fault_tolerance
127125
self.relaunch_timeout_worker = relaunch_timeout_worker
128126

elasticdl/python/tests/worker_test.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,8 @@
1616

1717
import tensorflow as tf
1818

19-
from elasticai_api.proto import elasticai_api_pb2
2019
from elasticdl.python.common.args import parse_worker_args
2120
from elasticdl.python.worker.worker import Worker
22-
from elasticdl_client.common.constants import DistributionStrategy
2321

2422

2523
class WorkerTest(unittest.TestCase):
@@ -35,29 +33,6 @@ def _create_worker(self, arguments):
3533
args = parse_worker_args(arguments)
3634
return Worker(args)
3735

38-
def test_init_training_func_from_args(self):
39-
arguments = [
40-
"--worker_id",
41-
"0",
42-
"--job_type",
43-
elasticai_api_pb2.TRAINING,
44-
"--minibatch_size",
45-
self._batch_size,
46-
"--model_zoo",
47-
self._model_zoo_path,
48-
"--model_def",
49-
"mnist.mnist_train_tfv2.train",
50-
"--distribution_strategy",
51-
DistributionStrategy.ALLREDUCE,
52-
"--custom_training_loop",
53-
"true",
54-
]
55-
worker = self._create_worker(arguments)
56-
self.assertIsNotNone(worker._feed)
57-
self.assertIsNotNone(worker._training_func)
58-
self.assertEqual(worker._minibatch_size, 16)
59-
self.assertIsNotNone(worker._task_data_service)
60-
6136

6237
if __name__ == "__main__":
6338
unittest.main()

elasticdl/python/worker/worker.py

Lines changed: 4 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from elasticdl.python.common.model_utils import (
2626
get_dict_from_params_str,
2727
get_model_spec,
28-
get_training_func_spec,
2928
set_callback_parameters,
3029
)
3130
from elasticdl.python.common.timing_utils import Timing
@@ -84,21 +83,16 @@ def __init__(
8483
self._timing = Timing(args.log_level.upper() == "DEBUG", self.logger)
8584
self._log_loss_count = 0
8685
self._var_created = False
87-
self._custom_training_loop = args.custom_training_loop
8886
self._job_type = args.job_type
8987
self._minibatch_size = args.minibatch_size
9088
self._data_shard_service = DataShardService(
9189
self._mc, self._minibatch_size
9290
)
93-
if self._custom_training_loop:
94-
self._init_training_func_from_args(args)
95-
else:
96-
self._init_model_from_args(args)
91+
self._init_model_from_args(args)
9792
self._init_task_data_service(args)
9893
self._init_default_feed_if_needed()
99-
if not self._custom_training_loop:
100-
self._init_callbacks(args)
101-
self._init_trainer(args)
94+
self._init_callbacks(args)
95+
self._init_trainer(args)
10296

10397
def _init_model_from_args(self, args):
10498
"""
@@ -172,19 +166,6 @@ def _init_trainer(self, args):
172166
self._model_inst, self._ps_client, self._timing, args
173167
)
174168

175-
def _init_training_func_from_args(self, args):
176-
self._job_type = args.job_type
177-
(
178-
self._training_func,
179-
self._feed,
180-
self._custom_data_reader,
181-
) = get_training_func_spec(
182-
model_zoo=args.model_zoo,
183-
model_def=args.model_def,
184-
feed=args.feed,
185-
custom_data_reader=args.custom_data_reader,
186-
)
187-
188169
def _init_default_feed_if_needed(self):
189170
if self._feed is None:
190171
if hasattr(self._task_data_service.data_reader, "default_feed"):
@@ -465,48 +446,4 @@ def run(self):
465446
elif self._job_type == JobType.EVALUATION_ONLY:
466447
self._evaluate_only()
467448
else:
468-
if self._custom_training_loop:
469-
self._elastic_allreduce_train()
470-
else:
471-
self._train_and_evaluate()
472-
473-
def _elastic_allreduce_train(self):
474-
"""
475-
Train and evaluate the model on the worker
476-
"""
477-
if os.getenv("USE_TORCH", None):
478-
from elasticai_api.pytorch.controller import (
479-
PyTorchAllReduceController,
480-
)
481-
482-
elastic_controller = PyTorchAllReduceController(
483-
self._mc, self._data_shard_service
484-
)
485-
elif _IS_TF2:
486-
from elasticai_api.tensorflow.controller import (
487-
TensorFlowV2AllReduceController,
488-
)
489-
490-
elastic_controller = TensorFlowV2AllReduceController(
491-
self._mc, self._data_shard_service
492-
)
493-
else:
494-
from elasticai_api.tensorflow.controller import (
495-
TensorFlowV1AllReduceController,
496-
)
497-
498-
elastic_controller = TensorFlowV1AllReduceController(
499-
self._mc, self._master_addr
500-
)
501-
# Initialize Horovod locally to generate varibles of the model
502-
# and optimizer.
503-
elastic_controller.init_horovod_locally()
504-
dataset = self._task_data_service.get_dataset()
505-
dataset = self._feed(
506-
dataset,
507-
Mode.TRAINING,
508-
self._task_data_service.data_reader.metadata,
509-
)
510-
dataset = dataset.batch(self._minibatch_size).prefetch(1)
511-
self._training_func(dataset, elastic_controller)
512-
del dataset
449+
self._train_and_evaluate()

elasticdl_client/common/args.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,6 @@ def add_train_params(parser):
140140
help="If True, PS will modulate the learning rate with staleness "
141141
"in asynchronous SGD",
142142
)
143-
add_bool_param(
144-
parser=parser,
145-
name="--custom_training_loop",
146-
default=False,
147-
help="If true, users need to define training loop by themselves. "
148-
"Otherwise, users should define a Keras model",
149-
)
150143
add_bool_param(
151144
parser=parser,
152145
name="--need_elasticdl_job_service",

model_zoo/mnist/mnist_pytorch.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
Download the mnist dataset from
1616
https://s3.amazonaws.com/fast-ai-imageclas/mnist_png.tgz
1717
and then untar it into ${data_store_dir}. Using minikube, we can use the
18-
following command to submit a training job with the script.
18+
following command to submit a training job with these codes.
1919
2020
elasticdl train \
2121
--image_name=elasticdl:pt_mnist_allreduce \
@@ -33,7 +33,6 @@
3333
--job_name=test-mnist-allreduce \
3434
--image_pull_policy=Never \
3535
--volume="host_path=${data_store_dir},mount_path=/local_data" \
36-
--custom_training_loop=true \
3736
--distribution_strategy=AllreduceStrategy \
3837
"""
3938

0 commit comments

Comments
 (0)