Skip to content

Commit 0c00b66

Browse files
authored
Worker can report training data to the master if using RecordIO (#2494)
* Worker can report training data to the master if using RecordIO * Pre-commit * Fix by comments
1 parent 6f8754c commit 0c00b66

File tree

6 files changed

+49
-5
lines changed

6 files changed

+49
-5
lines changed

elasticai_api/common/data_shard_service.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(
5454
shuffle_shards=False,
5555
task_type=elasticai_api_pb2.TRAINING,
5656
num_minibatches_per_shard=0,
57+
training_data=None,
5758
):
5859
self._mc = master_client
5960
self._batch_size = batch_size
@@ -63,6 +64,7 @@ def __init__(
6364
self._shuffle_shards = shuffle_shards
6465
self._task_type = task_type
6566
self._num_minibatches_per_shard = num_minibatches_per_shard
67+
self._training_data = training_data
6668
self._lock = threading.Lock()
6769
self._failed_record_count = 0
6870
self._reported_record_count = 0
@@ -79,6 +81,7 @@ def _report_training_params(self):
7981
shuffle=self._shuffle,
8082
shuffle_shards=self._shuffle_shards,
8183
num_minibatches_per_shard=self._num_minibatches_per_shard,
84+
training_data=self._training_data,
8285
)
8386

8487
def get_minibatch_count_per_epoch(self):
@@ -167,6 +170,7 @@ def __init__(
167170
dataset_size=None,
168171
task_type=elasticai_api_pb2.TRAINING,
169172
shuffle=False,
173+
training_data=None,
170174
):
171175
super(RecordIndexService, self).__init__(
172176
master_client=master_client,
@@ -175,6 +179,7 @@ def __init__(
175179
dataset_size=dataset_size,
176180
shuffle=shuffle,
177181
task_type=task_type,
182+
training_data=training_data,
178183
)
179184
self._shard_queue = SimpleQueue()
180185
threading.Thread(

elasticai_api/common/master_client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def report_training_params(
118118
shuffle=False,
119119
shuffle_shards=False,
120120
num_minibatches_per_shard=0,
121+
training_data=None,
121122
):
122123
request = elasticai_api_pb2.ReportTrainingParamsRequest()
123124
request.batch_size = batch_size
@@ -127,5 +128,7 @@ def report_training_params(
127128
request.num_epochs = num_epochs
128129
if dataset_size is not None:
129130
request.dataset_size = dataset_size
131+
if training_data is not None:
132+
request.training_data = training_data
130133
request.num_minibatches_per_shard = num_minibatches_per_shard
131134
return self._stub.report_training_params(request)

elasticai_api/proto/elasticai_api.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ message ReportTrainingParamsRequest {
6060
bool shuffle = 4;
6161
bool shuffle_shards = 5;
6262
int32 num_minibatches_per_shard = 6;
63+
string training_data = 7;
6364
}
6465

6566
message GetTaskRequest {

elasticdl/python/master/servicer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def report_training_params(self, request, _):
174174
request.shuffle,
175175
request.shuffle_shards,
176176
request.num_minibatches_per_shard,
177+
request.training_data,
177178
)
178179
return empty_pb2.Empty()
179180

elasticdl/python/master/task_manager.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
"""TaskQueue Implementation"""
1515

16+
import os
1617
import random
1718
import threading
1819
import time
@@ -229,17 +230,20 @@ def set_training_params(
229230
shuffle,
230231
shuffle_shards,
231232
num_minibatches_per_shard,
233+
training_data=None,
232234
):
233235
logger.info(
234236
"Set training parameters: "
235237
"batch_size={}, num_epochs={}, dataset_size={}, shuffle={}, "
236-
"shuffle_shards={}, num_minibatches_per_shard={}".format(
238+
"shuffle_shards={}, num_minibatches_per_shard={}, "
239+
"training_data={}".format(
237240
batch_size,
238241
num_epochs,
239242
dataset_size,
240243
shuffle,
241244
shuffle_shards,
242245
num_minibatches_per_shard,
246+
training_data,
243247
)
244248
)
245249

@@ -270,9 +274,20 @@ def set_training_params(
270274
self._dataset_size = (
271275
dataset_size if dataset_size > 0 else self._dataset_size
272276
)
273-
self._training_shards = self._create_shards_by_dataset_size(
274-
dataset_size
275-
)
277+
if (
278+
not dataset_size
279+
and training_data
280+
and os.path.isdir(training_data)
281+
):
282+
# The training_data is a directory only using RecordIO.
283+
self._create_training_tasks(training_data, None)
284+
elif dataset_size > 0:
285+
self._training_shards = self._create_shards_by_dataset_size(
286+
dataset_size
287+
)
288+
else:
289+
logger.warning("No data to create shards")
290+
276291
if self._training_shards:
277292
logger.info("Starting epoch %d", self._epoch)
278293
self.create_tasks(elasticai_api_pb2.TRAINING)

elasticdl/python/tests/task_manager_test.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,17 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
import tempfile
1415
import threading
1516
import time
1617
import unittest
1718

1819
from elasticai_api.proto import elasticai_api_pb2
19-
from elasticdl.python.tests.test_utils import create_task_manager
20+
from elasticdl.python.tests.test_utils import (
21+
DatasetName,
22+
create_recordio_file,
23+
create_task_manager,
24+
)
2025

2126

2227
class TaskManagerTest(unittest.TestCase):
@@ -176,6 +181,20 @@ def test_set_training_params(self):
176181
)
177182
self.assertEqual(len(task_manager._todo), 4)
178183

184+
task_manager = create_task_manager([], [])
185+
num_records = 128
186+
with tempfile.TemporaryDirectory() as temp_dir_name:
187+
shard_name = create_recordio_file(
188+
num_records, DatasetName.TEST_MODULE, 1, temp_dir=temp_dir_name
189+
)
190+
191+
task_manager.set_training_params(
192+
1, 1, 0, False, False, 3, temp_dir_name
193+
)
194+
self.assertEqual(
195+
task_manager._training_shards, [(shard_name, 0, num_records)]
196+
)
197+
179198

180199
if __name__ == "__main__":
181200
unittest.main()

0 commit comments

Comments (0)