Commit e7b9e64

galrotem authored and facebook-github-bot committed
move rank_zero_read_and_broadcast to distributed utils (#796)
Summary:
Pull Request resolved: #796

Reviewed By: diego-urgell, JKSenthil

Differential Revision: D56506784

fbshipit-source-id: 331450f67abe2a60653b546a9d3bf60045daaf2a
1 parent e6739ab commit e7b9e64

5 files changed: +61 / -73 lines changed
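In practical terms, this commit relocates the decorator from the private checkpoint-utils module to the public distributed utils module, so downstream code now imports it from torchtnt.utils.distributed. A minimal sketch of the new import path, using a hypothetical decorated reader that is not part of this diff:

from torchtnt.utils.distributed import rank_zero_read_and_broadcast

# Hypothetical helper: reads a small file on rank 0 only; the decorator
# broadcasts the returned string so every rank sees the same contents.
@rank_zero_read_and_broadcast
def read_marker_file(path: str) -> str:
    with open(path) as f:
        return f.read()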

tests/framework/callbacks/test_checkpoint_utils.py

Lines changed: 0 additions & 20 deletions
@@ -421,23 +421,3 @@ def test_get_app_state(self) -> None:
             app_state.keys(),
             ["module", "optimizer", "loss_fn", "train_progress"],
         )
-
-    @skip_if_not_distributed
-    def test_rank_zero_read_and_broadcast(self) -> None:
-        spawn_multi_process(2, "gloo", self._test_rank_zero_read_and_broadcast)
-
-    @staticmethod
-    def _test_rank_zero_read_and_broadcast() -> None:
-        """
-        Tests that rank_zero_read_and_broadcast decorator works as expected
-        """
-
-        @rank_zero_read_and_broadcast
-        def _test_method_for_rank_zero() -> str:
-            assert get_global_rank() == 0
-            return "foo"
-
-        init_from_env()
-        val_from_test_method = _test_method_for_rank_zero()
-        tc = unittest.TestCase()
-        tc.assertEqual(val_from_test_method, "foo")

tests/utils/test_distributed.py

Lines changed: 20 additions & 0 deletions
@@ -30,6 +30,7 @@
     get_world_size,
     PGWrapper,
     rank_zero_fn,
+    rank_zero_read_and_broadcast,
     revert_sync_batchnorm,
     spawn_multi_process,
     sync_bool,
@@ -443,3 +444,22 @@ def _test_method(offset_arg: int, offset_kwarg: int) -> int:
     def test_spawn_multi_process(self) -> None:
         mp_list = spawn_multi_process(2, "gloo", self._test_method, 3, offset_kwarg=2)
         self.assertEqual(mp_list, [1, 2])
+
+    @skip_if_not_distributed
+    def test_rank_zero_read_and_broadcast(self) -> None:
+        spawn_multi_process(2, "gloo", self._test_rank_zero_read_and_broadcast)
+
+    @staticmethod
+    def _test_rank_zero_read_and_broadcast() -> None:
+        """
+        Tests that rank_zero_read_and_broadcast decorator works as expected
+        """
+
+        @rank_zero_read_and_broadcast
+        def _test_method_for_rank_zero() -> str:
+            assert get_global_rank() == 0
+            return "foo"
+
+        val_from_test_method = _test_method_for_rank_zero()
+        tc = unittest.TestCase()
+        tc.assertEqual(val_from_test_method, "foo")
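As a side note, the moved test also shows the spawn_multi_process harness in action: it launches the given callable across the requested number of processes on the chosen backend, forwards extra args and kwargs, and collects the per-rank return values into a list. A small sketch under those assumptions, with an illustrative per-rank function:

from torchtnt.utils.distributed import get_global_rank, spawn_multi_process

def _report_rank() -> int:
    # Each spawned process runs this and returns its own global rank.
    return get_global_rank()

# Spawns 2 processes over the Gloo backend; the returned list gathers
# each rank's return value, e.g. [0, 1].
ranks = spawn_multi_process(2, "gloo", _report_rank)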

torchtnt/framework/callbacks/_checkpoint_utils.py

Lines changed: 2 additions & 51 deletions
@@ -10,18 +10,7 @@
 import os
 import re
 
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Pattern,
-    Tuple,
-    TypeVar,
-)
+from typing import Any, Dict, List, Literal, Optional, Pattern, Tuple, TypeVar
 
 import fsspec
 
@@ -30,7 +19,7 @@
 from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
 from torchtnt.framework.state import State
 from torchtnt.framework.unit import AppStateMixin
-from torchtnt.utils.distributed import get_global_rank, PGWrapper
+from torchtnt.utils.distributed import rank_zero_read_and_broadcast
 
 from torchtnt.utils.fsspec import get_filesystem
 from torchtnt.utils.stateful import Stateful
@@ -40,44 +29,6 @@
 T = TypeVar("T")
 
 
-def rank_zero_read_and_broadcast(
-    func: Callable[..., T],
-) -> Callable[..., T]:
-    """
-    Decorator that ensures a function is only executed by rank 0 and returns the result to all ranks.
-
-    Note:
-        By default will use the global process group. To use a custom process group, `process_group` must be an arg to the function and passed as a keyword argument.
-    """
-
-    def wrapper(*args: Any, **kwargs: Any) -> T:
-        ret = None
-        rank = get_global_rank()
-        process_group = kwargs.pop("process_group", None)
-
-        # Do all filesystem reads from rank 0 only
-        if rank == 0:
-            ret = func(*args, **kwargs)
-
-        # If not running in a distributed setting, return as is
-        if not (dist.is_available() and dist.is_initialized()):
-            # we cast here to avoid type errors, since it is
-            # guaranteed the return value is of type T
-            return cast(T, ret)
-
-        # Otherwise, broadcast result from rank 0 to all ranks
-        pg = PGWrapper(process_group)
-        path_container = [ret]
-        pg.broadcast_object_list(path_container, 0)
-        val = path_container[0]
-
-        # we cast here to avoid type errors, since it is
-        # guaranteed the return value is of type T
-        return cast(T, val)
-
-    return wrapper
-
-
 @rank_zero_read_and_broadcast
 def get_latest_checkpoint_path(
     dirpath: str,
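Because the decorator remains applied here, the checkpoint helpers keep their collective contract: in an initialized process group, every rank calls the function together, only rank 0 touches the filesystem, and the result is broadcast so all ranks agree. A usage sketch with a hypothetical dirpath, importing from the private module shown in this diff:

from torchtnt.framework.callbacks._checkpoint_utils import get_latest_checkpoint_path

# All ranks call this collectively; rank 0 performs the directory listing
# and the latest checkpoint path (or None) is broadcast to every rank.
latest = get_latest_checkpoint_path("/tmp/run/checkpoints")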

torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 1 addition & 2 deletions
@@ -24,7 +24,6 @@
     get_best_checkpoint_path,
     get_checkpoint_dirpaths,
     get_latest_checkpoint_path,
-    rank_zero_read_and_broadcast,
 )
 from torchtnt.framework.callbacks.checkpointer_types import (
     BestCheckpointConfig,
@@ -33,7 +32,7 @@
 from torchtnt.framework.state import EntryPoint, State
 from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit
 from torchtnt.framework.utils import get_timing_context
-from torchtnt.utils.distributed import PGWrapper
+from torchtnt.utils.distributed import PGWrapper, rank_zero_read_and_broadcast
 from torchtnt.utils.fsspec import get_filesystem
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 
torchtnt/utils/distributed.py

Lines changed: 38 additions & 0 deletions
@@ -590,3 +590,41 @@ def _init_pg_and_rank_and_launch_method(
 
     finally:
         destroy_process_group()
+
+
+def rank_zero_read_and_broadcast(
+    func: Callable[..., T],
+) -> Callable[..., T]:
+    """
+    Decorator that ensures a function is only executed by rank 0 and returns the result to all ranks.
+
+    Note:
+        By default will use the global process group. To use a custom process group, `process_group` must be an arg to the function and passed as a keyword argument.
+    """
+
+    def wrapper(*args: Any, **kwargs: Any) -> T:
+        ret = None
+        rank = get_global_rank()
+        process_group = kwargs.pop("process_group", None)
+
+        # Do all filesystem reads from rank 0 only
+        if rank == 0:
+            ret = func(*args, **kwargs)
+
+        # If not running in a distributed setting, return as is
+        if not (dist.is_available() and dist.is_initialized()):
+            # we cast here to avoid type errors, since it is
+            # guaranteed the return value is of type T
+            return cast(T, ret)
+
+        # Otherwise, broadcast result from rank 0 to all ranks
+        pg = PGWrapper(process_group)
+        path_container = [ret]
+        pg.broadcast_object_list(path_container, 0)
+        val = path_container[0]
+
+        # we cast here to avoid type errors, since it is
+        # guaranteed the return value is of type T
+        return cast(T, val)
+
+    return wrapper
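The docstring's note about custom process groups is easy to miss: process_group is consumed by the wrapper via kwargs.pop, so it must be passed as a keyword argument. A hedged sketch of both call styles; the decorated function and the subgroup are illustrative, not part of this diff:

import torch.distributed as dist

from torchtnt.utils.distributed import rank_zero_read_and_broadcast

@rank_zero_read_and_broadcast
def load_run_id() -> str:
    # Executed on rank 0 only; the return value is broadcast to all ranks.
    return "run-123"

# Default: broadcasts over the global process group (or returns directly
# if torch.distributed is unavailable or uninitialized).
run_id = load_run_id()

# Custom group: must be passed as a keyword argument, since the wrapper
# pops "process_group" from **kwargs before invoking the wrapped function.
subgroup = dist.new_group(ranks=[0, 1])
run_id = load_run_id(process_group=subgroup)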
