Commit 3dfcb7d

JKSenthil authored and facebook-github-bot committed
add local_rank_zero_fn decorator (#982)
Summary:
Pull Request resolved: #982

Add `local_rank_zero_fn` decorator for functions that should be run by one process per host.

Reviewed By: galrotem, anshulverma

Differential Revision: D70935839

fbshipit-source-id: 4fb267966546c08c0786894a667a6dde678a8774
1 parent 055aa15 commit 3dfcb7d
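
For context, a minimal sketch of the usage pattern described in the summary: gating per-host setup work so that only one process on each machine performs it. The helper name `warm_node_cache` and the cache path are illustrative, not part of this commit.

import os

from torchtnt.utils.distributed import local_rank_zero_fn


@local_rank_zero_fn
def warm_node_cache(path: str) -> None:
    # Executes only on local rank 0 of each host; other local ranks get None back.
    os.makedirs(path, exist_ok=True)


warm_node_cache("/tmp/node_cache")  # illustrative path, one mkdir per host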

File tree

2 files changed: 44 additions & 0 deletions

tests/utils/test_distributed.py

Lines changed: 9 additions & 0 deletions
@@ -30,6 +30,7 @@
     get_process_group_backend_from_device,
     get_tcp_init_method,
     get_world_size,
+    local_rank_zero_fn,
     PGWrapper,
     rank_zero_fn,
     rank_zero_read_and_broadcast,

@@ -169,6 +170,14 @@ def foo() -> int:
         x = foo()
         assert x is None

+    def test_local_rank_zero_fn(self) -> None:
+        @local_rank_zero_fn
+        def foo() -> int:
+            return 1
+
+        x = foo()
+        assert x == 1
+
     def test_revert_sync_batchnorm(self) -> None:
         original_batchnorm = torch.nn.modules.batchnorm.BatchNorm1d(4)

torchtnt/utils/distributed.py

Lines changed: 35 additions & 0 deletions
@@ -377,6 +377,41 @@ def wrapped_fn(*args: TParams.args, **kwargs: TParams.kwargs) -> Optional[TReturn]:
     return wrapped_fn


+def local_rank_zero_fn(
+    fn: Callable[TParams, TReturn]
+) -> Callable[TParams, Optional[TReturn]]:
+    """Function that can be used as a decorator to enable a function to be called on local rank 0 only.
+
+    Note:
+        This decorator should be used judiciously. It should never be used on functions that need synchronization,
+        and it should be used very carefully with functions that mutate local state.
+
+    Example:
+
+        >>> from torchtnt.utils.distributed import local_rank_zero_fn
+        >>> @local_rank_zero_fn
+        ... def foo():
+        ...     return 1
+        ...
+        >>> x = foo()  # x is 1 if local rank is 0, else x is None
+
+    Args:
+        fn: the function to be executed on local rank 0 only
+
+    Returns:
+        wrapped_fn: the wrapped function that executes only if the local rank is 0
+
+    """
+
+    @wraps(fn)
+    def wrapped_fn(*args: TParams.args, **kwargs: TParams.kwargs) -> Optional[TReturn]:
+        if get_local_rank() == 0:
+            return fn(*args, **kwargs)
+        return None
+
+    return wrapped_fn
+
+
 class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
     """
     The only difference between :class:`torch.nn.BatchNorm1d`, :class:`torch.nn.BatchNorm2d`,
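
For comparison, a minimal sketch contrasting the new decorator with the existing `rank_zero_fn` from the same module. On a hypothetical job with 2 hosts running 4 processes each (counts are illustrative), `rank_zero_fn` executes its body on global rank 0 only, i.e. one process in the entire job, while `local_rank_zero_fn` executes it on local rank 0 of every host, i.e. one process per machine; every other process receives None.

from torchtnt.utils.distributed import local_rank_zero_fn, rank_zero_fn


@rank_zero_fn
def announce_job() -> str:
    # Runs on global rank 0 only: exactly one process in the whole job.
    return "job started"


@local_rank_zero_fn
def announce_host() -> str:
    # Runs on local rank 0 of each host: one process per machine.
    return "host ready"


job_msg = announce_job()    # str on global rank 0, None elsewhere
host_msg = announce_host()  # str on local rank 0 of each host, None elsewhere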
