Skip to content

Commit 544a225

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
create gloo pg for DCPSaver.restore() (#874)
Summary: Pull Request resolved: #874 Reviewed By: galrotem Differential Revision: D60408282 fbshipit-source-id: a0aaf117203ed6dd1f5c6e79a955a1bd8f855821
1 parent 5c73bd5 commit 544a225

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed

tests/framework/callbacks/test_dcp_saver.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,17 @@ def test_restore_allow_partial_loading(self, mock_dist_cp: MagicMock) -> None:
455455
].allow_partial_load
456456
self.assertFalse(allow_partial_load)
457457

458+
@patch("torch.distributed.destroy_process_group")
459+
@patch("torchtnt.framework.callbacks.dcp_saver.dcp")
460+
def test_gloo_pg_restore(
461+
self, mock_dist_cp: MagicMock, mock_destroy_process_group: MagicMock
462+
) -> None:
463+
my_unit = DummyAutoUnit(module=nn.Linear(2, 3))
464+
DistributedCheckpointSaver.restore(path="path/to/snapshot", unit=my_unit)
465+
process_group = mock_dist_cp.load.call_args.kwargs["process_group"]
466+
self.assertEqual(process_group, None)
467+
mock_destroy_process_group.assert_not_called()
468+
458469

459470
class DummyStatefulDataLoader:
460471
def __init__(self, dataloader: DataLoader) -> None:

tests/framework/callbacks/test_dcp_saver_gpu.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
import shutil
1212
import tempfile
1313
import unittest
14+
from unittest.mock import MagicMock, patch
1415

1516
import torch
17+
from torch import distributed as dist, nn
1618

1719
from torchtnt.framework._test_utils import DummyAutoUnit, generate_random_dataloader
1820
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
@@ -22,6 +24,28 @@
2224

2325

2426
class DistributedCheckpointSaverGPUTest(unittest.TestCase):
27+
@skip_if_not_distributed
28+
@skip_if_not_gpu
29+
def test_test_gloo_pg_restore(self) -> None:
30+
spawn_multi_process(
31+
1,
32+
"nccl",
33+
self._test_gloo_pg_restore,
34+
)
35+
36+
@staticmethod
37+
@patch("torch.distributed.destroy_process_group")
38+
@patch("torchtnt.framework.callbacks.dcp_saver.dcp")
39+
def _test_gloo_pg_restore(
40+
mock_dist_cp: MagicMock, mock_destroy_process_group: MagicMock
41+
) -> None:
42+
tc = unittest.TestCase()
43+
my_unit = DummyAutoUnit(module=nn.Linear(2, 3))
44+
DistributedCheckpointSaver.restore(path="path/to/snapshot", unit=my_unit)
45+
process_group = mock_dist_cp.load.call_args.kwargs["process_group"]
46+
tc.assertEqual(dist.get_backend(process_group), dist.Backend.GLOO, None)
47+
mock_destroy_process_group.assert_called_once()
48+
2549
@skip_if_not_distributed
2650
@skip_if_not_gpu
2751
def test_save_restore_fsdp(self) -> None:

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
import time
1111
from concurrent.futures import Future
12+
from datetime import timedelta
1213
from typing import Any, Dict, Iterable, List, Optional, Union
1314

1415
import torch
@@ -273,19 +274,39 @@ def restore(
273274
) -> None:
274275
"""Utility method to restore dcp checkpoint from a path."""
275276

277+
# use gloo pg if available
278+
gloo_pg_created = False
279+
if dist.is_initialized():
280+
pg = dist.group.WORLD if process_group is None else process_group
281+
282+
if dist.get_backend(pg) != dist.Backend.GLOO:
283+
rank_zero_info(
284+
"Creating new gloo process group for loading checkpoint."
285+
)
286+
pg = dist.new_group(
287+
timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
288+
)
289+
gloo_pg_created = True
290+
else:
291+
pg = process_group
292+
276293
checkpoint_id = path
277294

278295
DistributedCheckpointSaver.restore_with_id(
279296
checkpoint_id,
280297
unit,
281298
train_dataloader=train_dataloader,
282-
process_group=process_group,
299+
process_group=pg,
283300
restore_options=restore_options,
284301
knob_options=knob_options,
285302
planner=planner,
286303
storage_reader=storage_reader,
287304
)
288305

306+
# destroy gloo pg if created, its sole purpose was for checkpoint restore
307+
if gloo_pg_created:
308+
dist.destroy_process_group(pg)
309+
289310
@staticmethod
290311
def restore_with_id(
291312
checkpoint_id: Union[int, str],

0 commit comments

Comments
 (0)