Skip to content

Commit ba310cf

Browse files
galrotem authored and
facebook-github-bot committed
add warmup steps to time wait for batch logger (#838)
Summary: Pull Request resolved: #838. Adding warmup steps to different performance loggers. Reviewed By: diego-urgell. Differential Revision: D57595989. fbshipit-source-id: 496edd9c3c3f92a2454eb9ae9c9e3bf7496d670c
1 parent 41b9918 commit ba310cf

File tree

2 files changed

+34
-6
lines changed

2 files changed

+34
-6
lines changed

tests/framework/callbacks/test_time_wait_for_batch_logger.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from torchtnt.framework.state import EntryPoint, PhaseState, State
2727
from torchtnt.framework.train import _train_impl
2828
from torchtnt.utils.loggers.logger import MetricLogger
29-
from torchtnt.utils.timer import TimerProtocol
29+
from torchtnt.utils.timer import Timer, TimerProtocol
3030

3131

3232
class TimeWaitForBatchLoggerTest(unittest.TestCase):
@@ -119,10 +119,28 @@ def test_with_predict(self) -> None:
119119
],
120120
)
121121

122-
def test_invalid_log_every_n_steps(self) -> None:
122+
def test_warmup_steps(self) -> None:
123+
logger = MagicMock(spec=MetricLogger)
124+
callback = TimeWaitForBatchLogger(logger=logger, warmup_steps=1)
125+
timer = Timer()
126+
timer.recorded_durations = {"data_wait_time": [1, 2]}
127+
128+
# ensure that we don't log for the first step
129+
callback._log_step_metrics(timer=timer, label="foo", step=1)
130+
logger.log.assert_not_called()
131+
132+
# second step should log
133+
callback._log_step_metrics(timer=timer, label="foo", step=2)
134+
self.assertEqual(logger.log.call_count, 1)
135+
136+
def test_invalid_params(self) -> None:
137+
logger_mock = MagicMock(spec=MetricLogger)
123138
with self.assertRaisesRegex(
124139
ValueError, "log_every_n_steps must be at least 1, got 0"
125140
):
126-
TimeWaitForBatchLogger(
127-
logger=MagicMock(spec=MetricLogger), log_every_n_steps=0
128-
)
141+
TimeWaitForBatchLogger(logger=logger_mock, log_every_n_steps=0)
142+
143+
with self.assertRaisesRegex(
144+
ValueError, "warmup_steps must be at least 0, got -1"
145+
):
146+
TimeWaitForBatchLogger(logger=logger_mock, warmup_steps=-1)

torchtnt/framework/callbacks/time_wait_for_batch_logger.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,16 @@ class TimeWaitForBatchLogger(Callback):
2525
Args:
2626
logger: Either a subclass of :class:`torchtnt.utils.loggers.logger.MetricLogger`
2727
or a :class:`torch.utils.tensorboard.SummaryWriter` instance.
28-
log_every_n_steps: an optional int to control the log frequency
28+
log_every_n_steps: an int to control the log frequency. Default is 1.
29+
warmup_steps: an int to control the number of warmup steps. Logging starts only after this number of warmup steps has been completed. Default is 0.
2930
"""
3031

3132
def __init__(
3233
self,
3334
logger: Union[MetricLogger, SummaryWriter],
35+
*,
3436
log_every_n_steps: int = 1,
37+
warmup_steps: int = 0,
3538
) -> None:
3639
self._logger = logger
3740
if log_every_n_steps < 1:
@@ -40,6 +43,10 @@ def __init__(
4043
)
4144
self._log_every_n_steps = log_every_n_steps
4245

46+
if warmup_steps < 0:
47+
raise ValueError(f"warmup_steps must be at least 0, got {warmup_steps}")
48+
self._warmup_steps = warmup_steps
49+
4350
@rank_zero_fn
4451
def _log_step_metrics(
4552
self,
@@ -48,6 +55,9 @@ def _log_step_metrics(
4855
label: str,
4956
step: int,
5057
) -> None:
58+
if step <= self._warmup_steps:
59+
return
60+
5161
if step % self._log_every_n_steps != 0:
5262
return
5363

0 commit comments

Comments
 (0)