Skip to content

Commit 40f5392

Browse files
authored
Develop an API to get training epoch (#2488)
* Check whether to register hooks according to HOROVOD_ELASTIC * Develop an API to get training epoch * Register hooks * Add unittest * Fix by comments * Fix unittest
1 parent 37c8c8f commit 40f5392

File tree

4 files changed

+27
-11
lines changed

4 files changed

+27
-11
lines changed

elasticai_api/common/data_shard_service.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ def _report_training_params(self):
8181
num_minibatches_per_shard=self._num_minibatches_per_shard,
8282
)
8383

84+
def get_minibatch_count_per_epoch(self):
85+
return self._dataset_size // self._batch_size
86+
8487
def get_current_task(self):
8588
return self._current_task
8689

elasticai_api/pytorch/controller.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,15 @@ def __init__(self, master_client, data_shard_service):
107107
os.getenv(WorkerEnv.WORKER_NUM, 1)
108108
)
109109
self.global_completed_batch_num = 0
110+
self.batch_count_per_epoch = (
111+
self.data_shard_service.get_minibatch_count_per_epoch()
112+
)
113+
114+
def get_current_epoch(self):
115+
return self.global_completed_batch_num // self.batch_count_per_epoch
116+
117+
def set_resume_epoch(self, epoch):
118+
self.global_completed_batch_num = epoch * self.batch_count_per_epoch
110119

111120
def set_broadcast_model(self, model):
112121
self._model = model
@@ -175,8 +184,8 @@ def reset_backward_passes_per_step(self):
175184
):
176185
world_size = hvd.size()
177186
rank = hvd.rank()
178-
self.backward_passes_per_step = int(
179-
self.global_batch_num_per_step / world_size
187+
self.backward_passes_per_step = (
188+
self.global_batch_num_per_step // world_size
180189
)
181190
if rank < self.global_batch_num_per_step % world_size:
182191
self.backward_passes_per_step += 1

elasticdl/python/tests/allreduce_trainer_test.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def test_elastic_run(self):
100100
rendezvous_id=1, rank_id=0, world_size=1, rendezvous_port=0
101101
)
102102
)
103-
data_shard_service = DataShardService(master_client, 1)
103+
data_shard_service = DataShardService(master_client, 1, 1, 10)
104104
controller = AllReduceController(master_client, data_shard_service)
105105
elastic_run = controller.elastic_run(self.train)
106106
elastic_run()
@@ -115,7 +115,7 @@ def setUp(self):
115115
rendezvous_id=1, rank_id=0, world_size=1, rendezvous_port=0
116116
)
117117
)
118-
data_shard_service = DataShardService(master_client, 1)
118+
data_shard_service = DataShardService(master_client, 1, 1, 10)
119119
self.controller = TensorFlowV2AllReduceController(
120120
master_client, data_shard_service
121121
)
@@ -145,7 +145,7 @@ def setUp(self):
145145
rendezvous_id=1, rank_id=0, world_size=1, rendezvous_port=0
146146
)
147147
)
148-
data_shard_service = DataShardService(master_client, 1)
148+
data_shard_service = DataShardService(master_client, 1, 1, 10)
149149
self.controller = PyTorchAllReduceController(
150150
master_client, data_shard_service
151151
)
@@ -171,7 +171,7 @@ def test_elastic_run(self):
171171
self.assertEqual(self.controller.global_completed_batch_num, 1)
172172

173173
def test_create_elastic_controller(self):
174-
controller = create_elastic_controller(batch_size=64)
174+
controller = create_elastic_controller(batch_size=64, dataset_size=128)
175175
self.assertIsNotNone(controller)
176176
self.assertIsNotNone(controller.data_shard_service._mc)
177177
self.assertEqual(controller.data_shard_service._batch_size, 64)
@@ -187,6 +187,14 @@ def test_reset_backward_passes_per_step(self):
187187
self.controller.reset_backward_passes_per_step()
188188
self.assertEqual(self.controller.backward_passes_per_step, 2)
189189

190+
def test_get_epoch(self):
191+
self.controller.batch_count_per_epoch = 10
192+
self.controller.global_completed_batch_num = 78
193+
epoch = self.controller.get_current_epoch()
194+
self.assertEqual(epoch, 7)
195+
self.controller.set_resume_epoch(5)
196+
self.assertEqual(self.controller.global_completed_batch_num, 50)
197+
190198

191199
if __name__ == "__main__":
192200
unittest.main()

model_zoo/mnist/mnist_pytorch.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ def train(args):
131131
device = torch.device("cuda" if use_cuda else "cpu")
132132
train_data = torchvision.datasets.ImageFolder(args.training_data)
133133
test_data = torchvision.datasets.ImageFolder(args.validation_data)
134-
batch_num_per_epoch = int(len(train_data.imgs) / args.batch_size)
135134

136135
allreduce_controller = create_elastic_controller(
137136
batch_size=args.batch_size,
@@ -171,10 +170,7 @@ def train(args):
171170
data, target = data.to(device), target.to(device)
172171
loss = elastic_train_one_batch(model, optimizer, data, target)
173172
print("loss = {}, step = {}".format(loss, batch_idx))
174-
new_epoch = int(
175-
allreduce_controller.global_completed_batch_num
176-
/ batch_num_per_epoch
177-
)
173+
new_epoch = allreduce_controller.get_current_epoch()
178174
if new_epoch > epoch:
179175
epoch = new_epoch
180176
# Set epoch of the scheduler

0 commit comments

Comments (0)