Commit c988b86

galrotem authored and facebook-github-bot committed
slow rank detector callback (#764)
Summary: Pull Request resolved: #764

Reviewed By: JKSenthil

Differential Revision: D55383095

fbshipit-source-id: fa42d0cf664c78016c9c95634788cca1d001a3cd
1 parent 244a3e9 commit c988b86

4 files changed, +258 −0 lines changed

docs/source/framework/callbacks.rst

Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@ We offer several pre-written callbacks which are ready to be used out of the box
     MemorySnapshot
     ModuleSummary
     PyTorchProfiler
+    SlowRankDetector
     SystemResourcesMonitor
     TensorBoardParameterMonitor
     TimeLimitInterrupter
Lines changed: 126 additions & 0 deletions

@@ -0,0 +1,126 @@

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from unittest import mock
from unittest.mock import MagicMock

from torchtnt.framework.callbacks.slow_rank_detector import (
    _get_min_max_indices,
    SlowRankDetector,
)

from torchtnt.framework.state import State
from torchtnt.framework.unit import TrainUnit
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.loggers.logger import MetricLogger
from torchtnt.utils.progress import Progress
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class SlowRankDetectorTest(unittest.TestCase):

    @skip_if_not_distributed
    @skip_if_not_gpu
    def test_sync_times(self) -> None:
        spawn_multi_process(2, "nccl", self._test_sync_times)

    @staticmethod
    def _test_sync_times() -> None:
        tc = unittest.TestCase()
        rank = get_global_rank()
        logger = MagicMock(spec=MetricLogger)

        with mock.patch("time.perf_counter", return_value=rank + 1), tc.assertLogs(
            level="INFO"
        ) as log:
            slow_rank_detector = SlowRankDetector(logger=logger)
            slow_rank_detector._sync_times(1, 1)
        tc.assertEqual(
            log.output,
            [
                "INFO:torchtnt.framework.callbacks.slow_rank_detector:Time difference between fastest rank (0: 1.0 sec) and slowest rank (1: 2.0 sec) is 1.0 seconds after 1 epochs and 1 steps."
            ],
        )
        if rank == 0:
            logger.log.assert_called_once_with(
                "Difference between fastest/slowest rank (seconds)", 1.0, 1
            )
        else:
            logger.log.assert_not_called()

    def test_get_min_max_indices(self) -> None:
        min_index, max_index = _get_min_max_indices([5.0, 2.0, 3.5])
        self.assertEqual(min_index, 1)
        self.assertEqual(max_index, 0)

        min_index, max_index = _get_min_max_indices([1.0])
        self.assertEqual(min_index, 0)
        self.assertEqual(max_index, 0)

        min_index, max_index = _get_min_max_indices([2.0, 3.0, 2.0])
        self.assertEqual(min_index, 0)
        self.assertEqual(max_index, 1)

    def test_invalid_initialization_params(self) -> None:
        with self.assertRaisesRegex(
            ValueError,
            "At least one of check_every_n_steps or check_every_n_epochs must be specified.",
        ):
            SlowRankDetector(check_every_n_steps=None, check_every_n_epochs=None)

        with self.assertRaisesRegex(
            ValueError,
            "check_every_n_steps must be a positive integer. Value passed is 0",
        ):
            SlowRankDetector(check_every_n_steps=0)

        with self.assertRaisesRegex(
            ValueError,
            "check_every_n_epochs must be a positive integer. Value passed is 0",
        ):
            SlowRankDetector(check_every_n_epochs=0)

    def test_sync_times_frequency(self) -> None:
        slow_rank_detector = SlowRankDetector(
            check_every_n_steps=2, check_every_n_epochs=2
        )
        unit = MagicMock(spec=TrainUnit)
        unit.train_progress = Progress(num_epochs_completed=1, num_steps_completed=1)
        state = MagicMock(spec=State)
        with mock.patch.object(slow_rank_detector, "_sync_times") as sync_times_mock:
            # first step shouldn't trigger time sync
            slow_rank_detector.on_train_step_end(state, unit)
            sync_times_mock.assert_not_called()

            # second step should trigger time sync
            unit.train_progress.increment_step()
            slow_rank_detector.on_train_step_end(state, unit)
            sync_times_mock.assert_called_once()

            # third step shouldn't trigger time sync
            unit.train_progress.increment_step()
            sync_times_mock.reset_mock()
            slow_rank_detector.on_train_step_end(state, unit)
            sync_times_mock.assert_not_called()

            # first epoch shouldn't trigger time sync
            slow_rank_detector.on_train_epoch_end(state, unit)
            sync_times_mock.assert_not_called()

            # second epoch should trigger time sync
            unit.train_progress.increment_epoch()
            slow_rank_detector.on_train_epoch_end(state, unit)
            sync_times_mock.assert_called_once()

            # third epoch shouldn't trigger time sync
            unit.train_progress.increment_epoch()
            sync_times_mock.reset_mock()
            slow_rank_detector.on_train_epoch_end(state, unit)
            sync_times_mock.assert_not_called()
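
The distributed test above leans on a small trick worth spelling out: time.perf_counter is patched so each rank's reported time is its rank plus one, which makes the fastest/slowest gap in the asserted log line exactly 1.0 seconds. The following standalone sketch (not part of this commit) reproduces that arithmetic with a plain Python list standing in for the multi-process NCCL all-gather:

# Sketch only: why the asserted log message is deterministic. With
# time.perf_counter patched to return rank + 1, rank 0 reports 1.0 s and
# rank 1 reports 2.0 s, so the fastest/slowest gap is exactly 1.0 s.
import time
from unittest import mock

gathered_times = []
for rank in (0, 1):  # stand-in for the two spawned processes
    with mock.patch("time.perf_counter", return_value=rank + 1):
        gathered_times.append(time.perf_counter())

fastest, slowest = min(gathered_times), max(gathered_times)
assert (fastest, slowest) == (1.0, 2.0)
assert slowest - fastest == 1.0  # the "1.0 seconds" in the expected log output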

torchtnt/framework/callbacks/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -16,6 +16,7 @@
 from .memory_snapshot import MemorySnapshot
 from .module_summary import ModuleSummary
 from .pytorch_profiler import PyTorchProfiler
+from .slow_rank_detector import SlowRankDetector
 from .system_resources_monitor import SystemResourcesMonitor
 from .tensorboard_parameter_monitor import TensorBoardParameterMonitor
 from .time_limit_interrupter import TimeLimitInterrupter
@@ -35,6 +36,7 @@
     "MemorySnapshot",
     "ModuleSummary",
     "PyTorchProfiler",
+    "SlowRankDetector",
     "SystemResourcesMonitor",
     "TensorBoardParameterMonitor",
     "TimeLimitInterrupter",
torchtnt/framework/callbacks/slow_rank_detector.py

Lines changed: 129 additions & 0 deletions

@@ -0,0 +1,129 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import time
from typing import List, Optional, Tuple

import torch
from torch import distributed as dist
from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TTrainUnit
from torchtnt.utils.distributed import all_gather_tensors, get_global_rank
from torchtnt.utils.env import init_from_env
from torchtnt.utils.loggers.logger import MetricLogger

logger: logging.Logger = logging.getLogger(__name__)


class SlowRankDetector(Callback):
    """
    A callback which detects slow ranks every N steps/epochs by comparing the time on each process.
    This is useful to debug ranks which are lagging behind and are likely to cause an NCCL timeout.
    If a logger is passed, the difference between the fastest rank and the slowest rank is also reported.

    Args:
        check_every_n_steps: frequency of steps at which to check for slow ranks.
        check_every_n_epochs: frequency of epochs at which to check for slow ranks.
        pg: the process group to use for all_gather_tensors. If None, the default process group will be used.
        logger: an optional logger to log the time difference.
        device: the device that will be used to store the time as a tensor. If None, the device will be inferred from the environment.

    Note:
        It is recommended to use this callback after you detect a timeout, and to make sure this callback runs before
        the logic triggering the timeout (another callback, train_step, etc.).
    """

    def __init__(
        self,
        *,
        check_every_n_steps: Optional[int] = 100,
        check_every_n_epochs: Optional[int] = 1,
        pg: Optional[dist.ProcessGroup] = None,
        logger: Optional[MetricLogger] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        if not (check_every_n_steps or check_every_n_epochs):
            raise ValueError(
                "At least one of check_every_n_steps or check_every_n_epochs must be specified."
            )

        if check_every_n_steps is not None and check_every_n_steps <= 0:
            raise ValueError(
                f"check_every_n_steps must be a positive integer. Value passed is {check_every_n_steps}"
            )

        if check_every_n_epochs is not None and check_every_n_epochs <= 0:
            raise ValueError(
                f"check_every_n_epochs must be a positive integer. Value passed is {check_every_n_epochs}"
            )

        self._check_every_n_steps = check_every_n_steps
        self._check_every_n_epochs = check_every_n_epochs
        self._pg = pg
        self._logger = logger
        self._device: torch.device = device or init_from_env()
        self._rank: int = get_global_rank()

    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
        if (
            self._check_every_n_steps is not None
            and unit.train_progress.num_steps_completed % self._check_every_n_steps == 0
        ):
            self._sync_times(
                unit.train_progress.num_epochs_completed,
                unit.train_progress.num_steps_completed,
            )

    def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
        if (
            self._check_every_n_epochs is not None
            and unit.train_progress.num_epochs_completed % self._check_every_n_epochs
            == 0
        ):
            self._sync_times(
                unit.train_progress.num_epochs_completed,
                unit.train_progress.num_steps_completed,
            )

    def _sync_times(self, epochs: int, steps: int) -> None:
        curr_time = time.perf_counter()
        curr_time_tensor = torch.Tensor([curr_time]).to(self._device)
        timings_as_tensor_list = all_gather_tensors(curr_time_tensor, self._pg)
        timings_as_list: List[float] = [
            tensor.item() for tensor in timings_as_tensor_list
        ]
        fastest_rank, slowest_rank = _get_min_max_indices(timings_as_list)
        time_on_fastest_rank = timings_as_list[fastest_rank]
        time_on_slowest_rank = timings_as_list[slowest_rank]
        time_difference = time_on_slowest_rank - time_on_fastest_rank
        logger.info(
            f"""Time difference between fastest rank ({fastest_rank}: {time_on_fastest_rank} sec) and slowest rank ({slowest_rank}: {time_on_slowest_rank} sec) is {time_difference} seconds after {epochs} epochs and {steps} steps."""
        )
        if self._logger and self._rank == 0:
            self._logger.log(
                "Difference between fastest/slowest rank (seconds)",
                time_difference,
                steps,
            )


# instead of taking a dependency on numpy
def _get_min_max_indices(input_list: List[float]) -> Tuple[int, int]:
    min_index = -1
    max_index = -1
    min_value = float("inf")
    max_value = float("-inf")
    for rank, curr_value in enumerate(input_list):
        if curr_value < min_value:
            min_value = curr_value
            min_index = rank
        if curr_value > max_value:
            max_value = curr_value
            max_index = rank

    return min_index, max_index
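
For context, a minimal usage sketch of the new callback is shown below. It is not part of this commit: MyTrainUnit and dataloader are hypothetical placeholders, and the train(...) entry point with its callbacks= argument is assumed from torchtnt's framework API. The callback itself is registered via the public export added in __init__.py above.

# Usage sketch (not part of this commit). MyTrainUnit and dataloader are
# hypothetical; train(...) and its callbacks= argument are assumed from
# torchtnt's framework API.
from torchtnt.framework.callbacks import SlowRankDetector
from torchtnt.framework.train import train

slow_rank_detector = SlowRankDetector(
    check_every_n_steps=100,  # compare rank times every 100 train steps
    check_every_n_epochs=1,   # and at the end of every epoch
)

train(
    MyTrainUnit(),            # hypothetical TrainUnit subclass defined elsewhere
    dataloader,               # hypothetical training dataloader
    max_epochs=2,
    callbacks=[slow_rank_detector],
)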
