
Commit 6de95a5

galrotem authored and facebook-github-bot committed

add progress reporter callback (#785)

Summary: Pull Request resolved: #785
Reviewed By: JKSenthil
Differential Revision: D56175728
fbshipit-source-id: be61bf67dd0b0ac18d3633574ac7f91259e08432

1 parent 5beb537 · commit 6de95a5

File tree

4 files changed: +154 −0 lines changed

docs/source/framework/callbacks.rst

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ We offer several pre-written callbacks which are ready to be used out of the box
     LearningRateMonitor
     MemorySnapshot
     ModuleSummary
+    ProgressReporter
     PyTorchProfiler
     SlowRankDetector
     SystemResourcesMonitor
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+
+import torch
+from torchtnt.framework._test_utils import DummyAutoUnit
+from torchtnt.framework.callbacks.progress_reporter import ProgressReporter
+from torchtnt.framework.state import EntryPoint, State
+from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
+from torchtnt.utils.progress import Progress
+
+
+class ProgressReporterTest(unittest.TestCase):
+    def test_log_with_rank(self) -> None:
+        spawn_multi_process(2, "gloo", self._test_log_with_rank)
+
+    @staticmethod
+    def _test_log_with_rank() -> None:
+        progress_reporter = ProgressReporter()
+        unit = DummyAutoUnit(module=torch.nn.Linear(2, 2))
+        unit.train_progress = Progress(
+            num_epochs_completed=1,
+            num_steps_completed=5,
+            num_steps_completed_in_epoch=3,
+        )
+        unit.eval_progress = Progress(
+            num_epochs_completed=2,
+            num_steps_completed=15,
+            num_steps_completed_in_epoch=7,
+        )
+        state = State(entry_point=EntryPoint.FIT)
+        tc = unittest.TestCase()
+        with tc.assertLogs(level="INFO") as log:
+            progress_reporter.on_train_end(state, unit)
+        tc.assertEqual(
+            log.output,
+            [
+                f"INFO:torchtnt.framework.callbacks.progress_reporter:Progress Reporter: rank {get_global_rank()} at on_train_end. "
+                "Train progress: completed epochs: 1, completed steps: 5, completed steps in current epoch: 3. "
+                "Eval progress: completed epochs: 2, completed steps: 15, completed steps in current epoch: 7."
+            ],
+        )
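For reference, the expected log lines asserted above suggest the string format produced by Progress.get_progress_string(). The sketch below is an illustrative reconstruction inferred from those assertions, not code from this commit or the actual torchtnt implementation; the helper name get_progress_string_sketch is hypothetical.

from torchtnt.utils.progress import Progress

# Illustrative sketch only: approximates what Progress.get_progress_string()
# appears to return, based on the expected log output in the test above.
def get_progress_string_sketch(progress: Progress) -> str:
    # e.g. "completed epochs: 1, completed steps: 5, completed steps in current epoch: 3."
    return (
        f"completed epochs: {progress.num_epochs_completed}, "
        f"completed steps: {progress.num_steps_completed}, "
        f"completed steps in current epoch: {progress.num_steps_completed_in_epoch}."
    )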

torchtnt/framework/callbacks/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
 from .learning_rate_monitor import LearningRateMonitor
 from .memory_snapshot import MemorySnapshot
 from .module_summary import ModuleSummary
+from .progress_reporter import ProgressReporter
 from .pytorch_profiler import PyTorchProfiler
 from .slow_rank_detector import SlowRankDetector
 from .system_resources_monitor import SystemResourcesMonitor
@@ -36,6 +37,7 @@
     "LearningRateMonitor",
     "MemorySnapshot",
     "ModuleSummary",
+    "ProgressReporter",
     "PyTorchProfiler",
     "SlowRankDetector",
     "SystemResourcesMonitor",
torchtnt/framework/callbacks/progress_reporter.py

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+from typing import cast
+
+from torchtnt.framework.callback import Callback
+from torchtnt.framework.state import EntryPoint, State
+from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TPredictUnit, TTrainUnit
+from torchtnt.utils.distributed import get_global_rank
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class ProgressReporter(Callback):
+    """
+    A simple callback which logs the progress at each loop start/end, epoch start/end and step start/end.
+    This is useful to debug certain issues, for which the root cause might be unequal progress across ranks, for instance NCCL timeouts.
+    If used, it's recommended to pass this callback as the first item in the callbacks list.
+    """
+
+    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_train_start")
+
+    def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_train_epoch_start")
+
+    def on_train_step_start(self, state: State, unit: TTrainUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_train_step_start")
+
+    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_train_step_end")
+
+    def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_train_epoch_end")
+
+    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_train_end")
+
+    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_eval_start")
+
+    def on_eval_epoch_start(self, state: State, unit: TEvalUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_eval_epoch_start")
+
+    def on_eval_step_start(self, state: State, unit: TEvalUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_eval_step_start")
+
+    def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_eval_step_end")
+
+    def on_eval_epoch_end(self, state: State, unit: TEvalUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_eval_epoch_end")
+
+    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_eval_end")
+
+    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_predict_start")
+
+    def on_predict_epoch_start(self, state: State, unit: TPredictUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_predict_epoch_start")
+
+    def on_predict_step_start(self, state: State, unit: TPredictUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_predict_step_start")
+
+    def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_predict_step_end")
+
+    def on_predict_epoch_end(self, state: State, unit: TPredictUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_predict_epoch_end")
+
+    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
+        self._log_with_rank_and_unit(state, unit, "on_predict_end")
+
+    @classmethod
+    def _log_with_rank_and_unit(
+        cls, state: State, unit: AppStateMixin, hook: str
+    ) -> None:
+        output_str = f"Progress Reporter: rank {get_global_rank()} at {hook}."
+        if state.entry_point == EntryPoint.TRAIN:
+            output_str = f"{output_str} Train progress: {cast(TTrainUnit, unit).train_progress.get_progress_string()}"
+
+        elif state.entry_point == EntryPoint.EVALUATE:
+            output_str = f"{output_str} Eval progress: {cast(TEvalUnit, unit).eval_progress.get_progress_string()}"
+
+        elif state.entry_point == EntryPoint.PREDICT:
+            output_str = f"{output_str} Predict progress: {cast(TPredictUnit, unit).predict_progress.get_progress_string()}"
+
+        elif state.entry_point == EntryPoint.FIT:
+            output_str = f"{output_str} Train progress: {cast(TTrainUnit, unit).train_progress.get_progress_string()} Eval progress: {cast(TEvalUnit, unit).eval_progress.get_progress_string()}"
+
+        else:
+            raise ValueError(
+                f"State entry point {state.entry_point} is not supported in ProgressReporter"
+            )
+
+        logger.info(output_str)
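As the docstring recommends, ProgressReporter should be passed as the first item in the callbacks list. Below is a minimal usage sketch, not part of this commit, assuming a typical torchtnt fit() entry point; MyUnit, train_dataloader, and eval_dataloader are placeholders for an existing unit and dataloaders.

# Minimal usage sketch (illustrative, not from this commit).
from torchtnt.framework import fit
from torchtnt.framework.callbacks import ProgressReporter

callbacks = [
    ProgressReporter(),  # recommended to be first in the list
    # ... other callbacks ...
]
fit(
    MyUnit(),  # placeholder: an AutoUnit/TrainUnit subclass defined elsewhere
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_epochs=2,
    callbacks=callbacks,
)

Each rank then logs lines of the form "Progress Reporter: rank 0 at on_train_epoch_end. Train progress: ...", which can be compared across ranks when debugging hangs such as NCCL timeouts caused by unequal progress.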
