Skip to content

Commit d86828b

Browse files
diego-urgell authored and facebook-github-bot committed
Swap DCP restore ad-hoc gloo pg creation with context manager (#903)
Summary: Pull Request resolved: #903. Reviewed By: JKSenthil. Differential Revision: D63268179. fbshipit-source-id: 47294a0fd8c560d7abb3db426305ec8b522c0432
1 parent 3377801 commit d86828b

File tree

1 file changed

+9
-29
lines changed

1 file changed

+9
-29
lines changed

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 9 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
import logging
1010
import time
1111
from concurrent.futures import Future
12-
from datetime import timedelta
1312
from typing import Any, Dict, Iterable, List, Optional, Union
1413

1514
import torch.distributed as dist
@@ -45,6 +44,7 @@
4544
)
4645
from torchtnt.framework.utils import get_timing_context
4746
from torchtnt.utils.checkpoint import BestCheckpointConfig
47+
from torchtnt.utils.distributed import get_or_create_gloo_pg
4848
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
4949
from torchtnt.utils.stateful import MultiStateful, Stateful
5050

@@ -271,23 +271,6 @@ def restore_with_id(
271271
storage_reader: Instance of StorageReader used to perform reads. If this is not specified, it will automatically infer
272272
the reader based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: ``None``)
273273
"""
274-
275-
# use gloo pg if available
276-
gloo_pg_created = False
277-
if dist.is_initialized():
278-
pg = dist.group.WORLD if process_group is None else process_group
279-
280-
if dist.get_backend(pg) != dist.Backend.GLOO:
281-
rank_zero_info(
282-
"Creating new gloo process group for loading checkpoint."
283-
)
284-
pg = dist.new_group(
285-
timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
286-
)
287-
gloo_pg_created = True
288-
else:
289-
pg = process_group
290-
291274
restore_options = restore_options or RestoreOptions()
292275
app_state = _prepare_app_state_for_restore(unit, restore_options)
293276
checkpoint_id = str(checkpoint_id)
@@ -321,22 +304,19 @@ def restore_with_id(
321304
"train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot"
322305
)
323306

324-
dcp.load(
325-
{"app_state": MultiStateful(app_state)},
326-
checkpoint_id=checkpoint_id,
327-
storage_reader=storage_reader,
328-
planner=planner,
329-
process_group=pg,
330-
)
307+
with get_or_create_gloo_pg(candidate_pg=process_group) as pg:
308+
dcp.load(
309+
{"app_state": MultiStateful(app_state)},
310+
checkpoint_id=checkpoint_id,
311+
storage_reader=storage_reader,
312+
planner=planner,
313+
process_group=pg,
314+
)
331315

332316
rank_zero_info(
333317
f"Restored snapshot for checkpoint_id: {checkpoint_id}", logger=logger
334318
)
335319

336-
# destroy gloo pg if created, its sole purpose was for checkpoint restore
337-
if gloo_pg_created:
338-
dist.destroy_process_group(pg)
339-
340320
def _generate_checkpoint_and_upkeep(
341321
self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str
342322
) -> bool:

0 commit comments

Comments
 (0)