4 files changed, +72 −0 lines

Changed paths:
  tests/framework/callbacks
  torchtnt/framework/callbacks

Docs: callbacks listing
@@ -27,6 +27,7 @@ We offer several pre-written callbacks which are ready to be used out of the box
 LearningRateMonitor
 MemorySnapshot
 ModuleSummary
+PeriodicDistributedSync
 ProgressReporter
 PyTorchProfiler
 SlowRankDetector

New test file under tests/framework/callbacks:
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+from unittest.mock import MagicMock, patch
+
+from torchtnt.framework._test_utils import DummyPredictUnit
+
+from torchtnt.framework.callbacks.periodic_distributed_sync import (
+    PeriodicDistributedSync,
+)
+from torchtnt.framework.state import EntryPoint, State
+
+
+class PeriodicDistributedSyncTest(unittest.TestCase):
+    @patch("torchtnt.framework.callbacks.periodic_distributed_sync.barrier")
+    def test_frequency(self, barrier_mock: MagicMock) -> None:
+        pds = PeriodicDistributedSync(sync_every_n_steps=2)
+        unit = DummyPredictUnit(2)
+        state = State(entry_point=EntryPoint.PREDICT)
+        unit.predict_progress.increment_step()  # 1 step completed
+        pds.on_predict_step_end(state, unit)
+        barrier_mock.assert_not_called()
+
+        unit.predict_progress.increment_step()  # 2 steps completed
+        pds.on_predict_step_end(state, unit)
+        barrier_mock.assert_called_once()
__init__.py for torchtnt/framework/callbacks:

 from .learning_rate_monitor import LearningRateMonitor
 from .memory_snapshot import MemorySnapshot
 from .module_summary import ModuleSummary
+from .periodic_distributed_sync import PeriodicDistributedSync
 from .progress_reporter import ProgressReporter
 from .pytorch_profiler import PyTorchProfiler
 from .slow_rank_detector import SlowRankDetector

     "LearningRateMonitor",
     "MemorySnapshot",
     "ModuleSummary",
+    "PeriodicDistributedSync",
     "ProgressReporter",
     "PyTorchProfiler",
     "SlowRankDetector",
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+
9
+ import logging
10
+
11
+ from torchtnt .framework .callback import Callback
12
+ from torchtnt .framework .state import State
13
+ from torchtnt .framework .unit import TPredictUnit
14
+ from torchtnt .utils .distributed import barrier
15
+
16
+ logger : logging .Logger = logging .getLogger (__name__ )
17
+
18
+
19
+ class PeriodicDistributedSync (Callback ):
20
+ """
21
+ A callback to sync all distributed workers at a given frequency.
22
+ Helpful when using distributed without DDP/FSDP but would still like to ensure that the workers are in sync with each other, for example large predict jobs.
23
+ Note that only predict is supported at the moment.
24
+
25
+ Args:
26
+ sync_every_n_steps: the frequency at which to sync the workers.
27
+ """
28
+
29
+ def __init__ (self , sync_every_n_steps : int = 1000 ) -> None :
30
+ self .sync_every_n_steps = sync_every_n_steps
31
+
32
+ def on_predict_step_end (self , state : State , unit : TPredictUnit ) -> None :
33
+ num_steps = unit .predict_progress .num_steps_completed
34
+ if num_steps % self .sync_every_n_steps == 0 :
35
+ logger .info (f"Barrier at step { num_steps } " )
36
+ barrier ()
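For reference, a minimal usage sketch of the new callback, assuming the standard torchtnt.framework.predict entry point with a callbacks argument; MyPredictUnit and build_predict_dataloader are hypothetical placeholders for a user-defined predict unit and dataloader factory, not part of this PR.

from torchtnt.framework import predict
from torchtnt.framework.callbacks import PeriodicDistributedSync

# Hypothetical user-defined predict unit and dataloader (placeholders).
unit = MyPredictUnit(module=my_module)
dataloader = build_predict_dataloader()

# Issue a barrier every 500 predict steps so that no rank runs far ahead of
# the others, e.g. in a large multi-host predict job that does not use DDP/FSDP.
predict(
    unit,
    dataloader,
    callbacks=[PeriodicDistributedSync(sync_every_n_steps=500)],
)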