Skip to content

Commit a8750ae

Browse files
galrotem authored and facebook-github-bot committed
periodic distributed sync (#843)
Summary: Pull Request resolved: #843 Reviewed By: diego-urgell Differential Revision: D58118753 fbshipit-source-id: 42d69c285ca36738a86020018b4137c3a9d20e1d
1 parent b21deb6 commit a8750ae

File tree

4 files changed

+72
-0
lines changed

4 files changed

+72
-0
lines changed

docs/source/framework/callbacks.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ We offer several pre-written callbacks which are ready to be used out of the box
2727
LearningRateMonitor
2828
MemorySnapshot
2929
ModuleSummary
30+
PeriodicDistributedSync
3031
ProgressReporter
3132
PyTorchProfiler
3233
SlowRankDetector
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import unittest
11+
from unittest.mock import MagicMock, patch
12+
13+
from torchtnt.framework._test_utils import DummyPredictUnit
14+
15+
from torchtnt.framework.callbacks.periodic_distributed_sync import (
16+
PeriodicDistributedSync,
17+
)
18+
from torchtnt.framework.state import EntryPoint, State
19+
20+
21+
class PeriodicDistributedSyncTest(unittest.TestCase):
    @patch("torchtnt.framework.callbacks.periodic_distributed_sync.barrier")
    def test_frequency(self, barrier_mock: MagicMock) -> None:
        """The barrier fires only when the completed-step count is a multiple of ``sync_every_n_steps``."""
        callback = PeriodicDistributedSync(sync_every_n_steps=2)
        unit = DummyPredictUnit(2)
        state = State(entry_point=EntryPoint.PREDICT)

        # After one completed step (not a multiple of 2) no barrier is expected.
        unit.predict_progress.increment_step()
        callback.on_predict_step_end(state, unit)
        barrier_mock.assert_not_called()

        # After the second completed step the barrier should fire exactly once.
        unit.predict_progress.increment_step()
        callback.on_predict_step_end(state, unit)
        barrier_mock.assert_called_once()

torchtnt/framework/callbacks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .learning_rate_monitor import LearningRateMonitor
1717
from .memory_snapshot import MemorySnapshot
1818
from .module_summary import ModuleSummary
19+
from .periodic_distributed_sync import PeriodicDistributedSync
1920
from .progress_reporter import ProgressReporter
2021
from .pytorch_profiler import PyTorchProfiler
2122
from .slow_rank_detector import SlowRankDetector
@@ -39,6 +40,7 @@
3940
"LearningRateMonitor",
4041
"MemorySnapshot",
4142
"ModuleSummary",
43+
"PeriodicDistributedSync",
4244
"ProgressReporter",
4345
"PyTorchProfiler",
4446
"SlowRankDetector",
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import logging
10+
11+
from torchtnt.framework.callback import Callback
12+
from torchtnt.framework.state import State
13+
from torchtnt.framework.unit import TPredictUnit
14+
from torchtnt.utils.distributed import barrier
15+
16+
logger: logging.Logger = logging.getLogger(__name__)
17+
18+
19+
class PeriodicDistributedSync(Callback):
    """
    A callback to sync all distributed workers at a given frequency.
    Helpful when using distributed without DDP/FSDP but would still like to ensure that the workers are in sync with each other, for example large predict jobs.
    Note that only predict is supported at the moment.

    Args:
        sync_every_n_steps: the frequency at which to sync the workers.

    Raises:
        ValueError: if ``sync_every_n_steps`` is not a positive integer.
    """

    def __init__(self, sync_every_n_steps: int = 1000) -> None:
        # Validate eagerly: a value of 0 would otherwise surface as a
        # ZeroDivisionError on the first step end, and a negative value would
        # silently trigger barriers via Python's sign-of-divisor modulo.
        if sync_every_n_steps < 1:
            raise ValueError(
                f"sync_every_n_steps must be a positive integer, got {sync_every_n_steps}"
            )
        self.sync_every_n_steps = sync_every_n_steps

    def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
        num_steps = unit.predict_progress.num_steps_completed
        if num_steps % self.sync_every_n_steps == 0:
            # Lazy %-style args: this runs on every sync step, so avoid
            # building the message string when INFO logging is disabled.
            logger.info("Barrier at step %d", num_steps)
            barrier()

0 commit comments

Comments
 (0)