|
9 | 9 | import logging
|
10 | 10 | import time
|
11 | 11 | from concurrent.futures import Future
|
12 |
| -from datetime import timedelta |
13 | 12 | from typing import Any, Dict, Iterable, List, Optional, Union
|
14 | 13 |
|
15 | 14 | import torch.distributed as dist
|
|
45 | 44 | )
|
46 | 45 | from torchtnt.framework.utils import get_timing_context
|
47 | 46 | from torchtnt.utils.checkpoint import BestCheckpointConfig
|
| 47 | +from torchtnt.utils.distributed import get_or_create_gloo_pg |
48 | 48 | from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
|
49 | 49 | from torchtnt.utils.stateful import MultiStateful, Stateful
|
50 | 50 |
|
@@ -271,23 +271,6 @@ def restore_with_id(
|
271 | 271 | storage_reader: Instance of StorageReader used to perform reads. If this is not specified, it will automatically infer
|
272 | 272 | the reader based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: ``None``)
|
273 | 273 | """
|
274 |
| - |
275 |
| - # use gloo pg if available |
276 |
| - gloo_pg_created = False |
277 |
| - if dist.is_initialized(): |
278 |
| - pg = dist.group.WORLD if process_group is None else process_group |
279 |
| - |
280 |
| - if dist.get_backend(pg) != dist.Backend.GLOO: |
281 |
| - rank_zero_info( |
282 |
| - "Creating new gloo process group for loading checkpoint." |
283 |
| - ) |
284 |
| - pg = dist.new_group( |
285 |
| - timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO |
286 |
| - ) |
287 |
| - gloo_pg_created = True |
288 |
| - else: |
289 |
| - pg = process_group |
290 |
| - |
291 | 274 | restore_options = restore_options or RestoreOptions()
|
292 | 275 | app_state = _prepare_app_state_for_restore(unit, restore_options)
|
293 | 276 | checkpoint_id = str(checkpoint_id)
|
@@ -321,22 +304,19 @@ def restore_with_id(
|
321 | 304 | "train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot"
|
322 | 305 | )
|
323 | 306 |
|
324 |
| - dcp.load( |
325 |
| - {"app_state": MultiStateful(app_state)}, |
326 |
| - checkpoint_id=checkpoint_id, |
327 |
| - storage_reader=storage_reader, |
328 |
| - planner=planner, |
329 |
| - process_group=pg, |
330 |
| - ) |
| 307 | + with get_or_create_gloo_pg(candidate_pg=process_group) as pg: |
| 308 | + dcp.load( |
| 309 | + {"app_state": MultiStateful(app_state)}, |
| 310 | + checkpoint_id=checkpoint_id, |
| 311 | + storage_reader=storage_reader, |
| 312 | + planner=planner, |
| 313 | + process_group=pg, |
| 314 | + ) |
331 | 315 |
|
332 | 316 | rank_zero_info(
|
333 | 317 | f"Restored snapshot for checkpoint_id: {checkpoint_id}", logger=logger
|
334 | 318 | )
|
335 | 319 |
|
336 |
| - # destroy gloo pg if created, its sole purpose was for checkpoint restore |
337 |
| - if gloo_pg_created: |
338 |
| - dist.destroy_process_group(pg) |
339 |
| - |
340 | 320 | def _generate_checkpoint_and_upkeep(
|
341 | 321 | self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str
|
342 | 322 | ) -> bool:
|
|
0 commit comments