
Commit 882832c

galrotem authored and facebook-github-bot committed
add eval support to distributed sync (#847)
Summary:
Pull Request resolved: #847

Support also eval

Reviewed By: diego-urgell

Differential Revision: D58855082

fbshipit-source-id: b2ed61e27d1d8a786ba3352f14437b48097a2e4d
1 parent 0f72333 commit 882832c

2 files changed: +26 lines, −6 lines


tests/framework/callbacks/test_periodic_distributed_sync.py

Lines changed: 15 additions & 2 deletions
@@ -10,7 +10,7 @@
 import unittest
 from unittest.mock import MagicMock, patch
 
-from torchtnt.framework._test_utils import DummyPredictUnit
+from torchtnt.framework._test_utils import DummyEvalUnit, DummyPredictUnit
 
 from torchtnt.framework.callbacks.periodic_distributed_sync import (
     PeriodicDistributedSync,
@@ -20,7 +20,7 @@
 
 class PeriodicDistributedSyncTest(unittest.TestCase):
     @patch("torchtnt.framework.callbacks.periodic_distributed_sync.barrier")
-    def test_frequency(self, barrier_mock: MagicMock) -> None:
+    def test_frequency_predict(self, barrier_mock: MagicMock) -> None:
         pds = PeriodicDistributedSync(sync_every_n_steps=2)
         unit = DummyPredictUnit(2)
         state = State(entry_point=EntryPoint.PREDICT)
@@ -31,3 +31,16 @@ def test_frequency(self, barrier_mock: MagicMock) -> None:
         unit.predict_progress.increment_step()  # 2 steps completed
         pds.on_predict_step_end(state, unit)
         barrier_mock.assert_called_once()
+
+    @patch("torchtnt.framework.callbacks.periodic_distributed_sync.barrier")
+    def test_frequency_evaluate(self, barrier_mock: MagicMock) -> None:
+        pds = PeriodicDistributedSync(sync_every_n_steps=2)
+        unit = DummyEvalUnit(2)
+        state = State(entry_point=EntryPoint.EVALUATE)
+        unit.eval_progress.increment_step()  # 1 step completed
+        pds.on_eval_step_end(state, unit)
+        barrier_mock.assert_not_called()
+
+        unit.eval_progress.increment_step()  # 2 steps completed
+        pds.on_eval_step_end(state, unit)
+        barrier_mock.assert_called_once()

torchtnt/framework/callbacks/periodic_distributed_sync.py

Lines changed: 11 additions & 4 deletions
@@ -10,8 +10,8 @@
 
 from torchtnt.framework.callback import Callback
 from torchtnt.framework.state import State
-from torchtnt.framework.unit import TPredictUnit
-from torchtnt.utils.distributed import barrier
+from torchtnt.framework.unit import TEvalUnit, TPredictUnit
+from torchtnt.utils.distributed import barrier, get_global_rank
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -20,17 +20,24 @@ class PeriodicDistributedSync(Callback):
     """
     A callback to sync all distributed workers at a given frequency.
     Helpful when using distributed without DDP/FSDP but would still like to ensure that the workers are in sync with each other, for example large predict jobs.
-    Note that only predict is supported at the moment.
+    Both predict and evaluate are supported.
 
     Args:
         sync_every_n_steps: the frequency at which to sync the workers.
     """
 
     def __init__(self, sync_every_n_steps: int = 1000) -> None:
         self.sync_every_n_steps = sync_every_n_steps
+        self._global_rank: int = get_global_rank()
 
     def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
         num_steps = unit.predict_progress.num_steps_completed
         if num_steps % self.sync_every_n_steps == 0:
-            logger.info(f"Barrier at step {num_steps}")
+            logger.info(f"Barrier at step {num_steps} on rank {self._global_rank}")
+            barrier()
+
+    def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
+        num_steps = unit.eval_progress.num_steps_completed
+        if num_steps % self.sync_every_n_steps == 0:
+            logger.info(f"Barrier at step {num_steps} on rank {self._global_rank}")
             barrier()
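
For context, the following is a minimal sketch (not part of this commit) of how the callback can now be attached to an eval run through torchtnt's evaluate entry point. The MyEvalUnit class, the Linear module, and the list-based dataloader are illustrative assumptions, not code from this repository; in a real job this would run under torch.distributed with a process group initialized so that barrier() actually synchronizes ranks.

import torch

from torchtnt.framework import evaluate
from torchtnt.framework.callbacks.periodic_distributed_sync import (
    PeriodicDistributedSync,
)
from torchtnt.framework.state import State
from torchtnt.framework.unit import EvalUnit


class MyEvalUnit(EvalUnit[torch.Tensor]):
    # Hypothetical eval unit used only to illustrate the callback; a real unit
    # would move data to the right device, compute metrics, etc.
    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module

    def eval_step(self, state: State, data: torch.Tensor) -> None:
        self.module(data)


unit = MyEvalUnit(torch.nn.Linear(4, 2))
dataloader = [torch.randn(8, 4) for _ in range(10)]

# Sync all ranks every 2 eval steps (the callback defaults to every 1000 steps).
evaluate(unit, dataloader, callbacks=[PeriodicDistributedSync(sync_every_n_steps=2)])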
