
Commit 33b98f4

saumishr authored and facebook-github-bot committed
Use Gloo PG if available for both restore and restore_with_id methods (#897)
Summary: Pull Request resolved: #897

Use Gloo PG if available for both restore and restore_with_id methods. This diff moves the logic into restore_with_id, which is called by the restore method, so the Gloo fallback takes effect on both code paths.

Reviewed By: JKSenthil

Differential Revision: D62539308

fbshipit-source-id: bb37c2ce0e33027967c7ef5727ca09c3ec491fc6
1 parent 665dd50 commit 33b98f4
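
The core of the change is a Gloo fallback for checkpoint restore: if the job's process group uses a non-Gloo backend (e.g. NCCL), a temporary Gloo group is created just for the load and destroyed afterwards, since Gloo handles the CPU-side communication that checkpoint restore relies on. Below is a condensed sketch of the logic this diff relocates into restore_with_id; the helper name _get_or_create_gloo_pg is hypothetical (in the real code this logic is inlined in dcp_saver.py, as shown in the diff below).

# Condensed sketch of the Gloo fallback moved into restore_with_id.
# The helper name is hypothetical; torchtnt inlines this logic.
from datetime import timedelta
from typing import Optional, Tuple

import torch.distributed as dist


def _get_or_create_gloo_pg(
    process_group: Optional[dist.ProcessGroup],
) -> Tuple[Optional[dist.ProcessGroup], bool]:
    """Return (pg, created): a Gloo-backed group if dist is initialized."""
    if not dist.is_initialized():
        return process_group, False
    pg = dist.group.WORLD if process_group is None else process_group
    if dist.get_backend(pg) != dist.Backend.GLOO:
        # create a CPU-capable Gloo group solely for checkpoint restore
        pg = dist.new_group(
            timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
        )
        return pg, True
    return pg, False

Because restore simply delegates to restore_with_id, moving the fallback into restore_with_id covers both entry points with one copy of the logic.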

File tree

2 files changed: +98 −24 lines


tests/framework/callbacks/test_dcp_saver_gpu.py

Lines changed: 73 additions & 0 deletions
@@ -46,6 +46,30 @@ def _test_gloo_pg_restore(
         tc.assertEqual(dist.get_backend(process_group), dist.Backend.GLOO, None)
         mock_destroy_process_group.assert_called_once()
 
+    @skip_if_not_distributed
+    @skip_if_not_gpu
+    def test_gloo_pg_restore_with_id(self) -> None:
+        spawn_multi_process(
+            1,
+            "nccl",
+            self._test_gloo_pg_restore_with_id,
+        )
+
+    @staticmethod
+    @patch("torch.distributed.destroy_process_group")
+    @patch("torchtnt.framework.callbacks.dcp_saver.dcp")
+    def _test_gloo_pg_restore_with_id(
+        mock_dist_cp: MagicMock, mock_destroy_process_group: MagicMock
+    ) -> None:
+        tc = unittest.TestCase()
+        my_unit = DummyAutoUnit(module=nn.Linear(2, 3))
+        DistributedCheckpointSaver.restore_with_id(
+            checkpoint_id="path/to/snapshot", unit=my_unit
+        )
+        process_group = mock_dist_cp.load.call_args.kwargs["process_group"]
+        tc.assertEqual(dist.get_backend(process_group), dist.Backend.GLOO, None)
+        mock_destroy_process_group.assert_called_once()
+
     @skip_if_not_distributed
     @skip_if_not_gpu
     def test_save_restore_fsdp(self) -> None:
@@ -94,3 +118,52 @@ def _save_restore_fsdp() -> None:
         finally:
             if get_global_rank() == 0:
                 shutil.rmtree(temp_dir)  # delete temp directory
+
+    @skip_if_not_distributed
+    @skip_if_not_gpu
+    def test_save_restore_fsdp_with_id(self) -> None:
+        spawn_multi_process(
+            2,
+            "nccl",
+            self._save_restore_fsdp_with_id,
+        )
+
+    @staticmethod
+    def _save_restore_fsdp_with_id() -> None:
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+        max_epochs = 2
+        save_every_n_epochs = 1
+
+        my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
+        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
+        if get_global_rank() == 0:
+            temp_dir = tempfile.mkdtemp()
+        else:
+            temp_dir = ""
+
+        dcp_cb = DistributedCheckpointSaver(
+            temp_dir,
+            save_every_n_epochs=save_every_n_epochs,
+        )
+        temp_dir = dcp_cb.dirpath
+        train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])
+
+        tc = unittest.TestCase()
+        try:
+            my_new_unit = DummyAutoUnit(
+                module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
+            )
+            tc.assertNotEqual(
+                my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
+            )
+            # get latest checkpoint
+            ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_train_step_10")
+            dcp_cb.restore_with_id(checkpoint_id=ckpt_path, unit=my_new_unit)
+            tc.assertEqual(
+                my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
+            )
+        finally:
+            if get_global_rank() == 0:
+                shutil.rmtree(temp_dir)  # delete temp directory
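
The first new test verifies the fallback without touching real storage: it patches the dcp module inside dcp_saver, calls restore_with_id, and asserts that the process group handed to dcp.load is Gloo-backed. A condensed sketch of that mock-inspection pattern, assuming it runs inside a process whose default group was initialized with "nccl" (the test arranges this via spawn_multi_process); the DummyAutoUnit import path is an assumption:

# Condensed sketch of the mock-inspection pattern used by the new test.
# Assumes dist.init_process_group("nccl") has already run in this process.
from unittest.mock import patch

import torch.distributed as dist
import torch.nn as nn

from torchtnt.framework._test_utils import DummyAutoUnit  # import path assumed
from torchtnt.framework.callbacks import DistributedCheckpointSaver

with patch("torchtnt.framework.callbacks.dcp_saver.dcp") as mock_dcp, patch(
    "torch.distributed.destroy_process_group"
):
    DistributedCheckpointSaver.restore_with_id(
        checkpoint_id="path/to/snapshot",  # placeholder id, as in the test
        unit=DummyAutoUnit(module=nn.Linear(2, 3)),
    )
    # restore_with_id should have handed dcp.load a Gloo-backed group
    pg = mock_dcp.load.call_args.kwargs["process_group"]
    assert dist.get_backend(pg) == dist.Backend.GLOO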

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 25 additions & 24 deletions
@@ -230,39 +230,19 @@ def restore(
     ) -> None:
         """Utility method to restore dcp checkpoint from a path."""
 
-        # use gloo pg if available
-        gloo_pg_created = False
-        if dist.is_initialized():
-            pg = dist.group.WORLD if process_group is None else process_group
-
-            if dist.get_backend(pg) != dist.Backend.GLOO:
-                rank_zero_info(
-                    "Creating new gloo process group for loading checkpoint."
-                )
-                pg = dist.new_group(
-                    timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
-                )
-                gloo_pg_created = True
-        else:
-            pg = process_group
-
         checkpoint_id = path
 
         DistributedCheckpointSaver.restore_with_id(
             checkpoint_id,
             unit,
             train_dataloader=train_dataloader,
-            process_group=pg,
+            process_group=process_group,
             restore_options=restore_options,
             knob_options=knob_options,
             planner=planner,
             storage_reader=storage_reader,
         )
 
-        # destroy gloo pg if created, its sole purpose was for checkpoint restore
-        if gloo_pg_created:
-            dist.destroy_process_group(pg)
-
     @staticmethod
     def restore_with_id(
         checkpoint_id: Union[int, str],
@@ -284,15 +264,32 @@ def restore_with_id(
             checkpoint_id: Checkpoint id. It can be the path of the snapshot to restore.
             unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
             train_dataloader: An optional train dataloader to restore.
-            process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) Note:
-                If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion.
+            process_group: The process group on which the ranks will communicate. default: ``None`` (the entire world)
+                If not Gloo, a Gloo process group is created.
+                Note: If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion.
             restore_options: Controls what to filter when restoring the state.
             knob_options: Additional keyword options for StorageWriter and StorageReader
             planner: Instance of LoadPlanner. If this is not specified, the default planner will be used. (Default: ``None``)
             storage_reader: Instance of StorageReader used to perform reads. If this is not specified, it will automatically infer
                 the reader based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: ``None``)
         """
 
+        # use gloo pg if available
+        gloo_pg_created = False
+        if dist.is_initialized():
+            pg = dist.group.WORLD if process_group is None else process_group
+
+            if dist.get_backend(pg) != dist.Backend.GLOO:
+                rank_zero_info(
+                    "Creating new gloo process group for loading checkpoint."
+                )
+                pg = dist.new_group(
+                    timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
+                )
+                gloo_pg_created = True
+        else:
+            pg = process_group
+
         restore_options = restore_options or RestoreOptions()
         app_state = _prepare_app_state_for_restore(unit, restore_options)
         checkpoint_id = str(checkpoint_id)
@@ -340,13 +337,17 @@ def restore_with_id(
             checkpoint_id=checkpoint_id,
             storage_reader=storage_reader,
             planner=planner,
-            process_group=process_group,
+            process_group=pg,
         )
 
         rank_zero_info(
             f"Restored snapshot for checkpoint_id: {checkpoint_id}", logger=logger
         )
 
+        # destroy gloo pg if created, its sole purpose was for checkpoint restore
+        if gloo_pg_created:
+            dist.destroy_process_group(pg)
+
     def _generate_checkpoint_and_upkeep(
         self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str
     ) -> bool:
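
From the caller's side the contract is unchanged: pass no process_group (or a non-Gloo one) and restore_with_id now creates the Gloo group internally, destroying it once the load finishes. A hypothetical usage sketch under an NCCL-initialized job (the DummyAutoUnit import path is assumed; any TrainUnit/AutoUnit works):

# Hypothetical caller-side usage after this change: no manual Gloo group.
import torch.distributed as dist
import torch.nn as nn

from torchtnt.framework._test_utils import DummyAutoUnit  # import path assumed
from torchtnt.framework.callbacks import DistributedCheckpointSaver

# e.g. launched via torchrun; ranks join the default NCCL group
dist.init_process_group(backend="nccl")

my_unit = DummyAutoUnit(module=nn.Linear(2, 3))
DistributedCheckpointSaver.restore_with_id(
    checkpoint_id="path/to/snapshot",  # placeholder id, as in the tests
    unit=my_unit,
)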
