Skip to content

Commit 41b9918

Browse files
galrotem authored and facebook-github-bot committed
add warmup steps to iteration time logger (#837)
Summary: Pull Request resolved: #837 Adding warmup steps to different performance loggers Reviewed By: diego-urgell Differential Revision: D57595957 fbshipit-source-id: e92bf0345ed93ea944cb751bdfd162535c69c332
1 parent a0e6830 commit 41b9918

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

tests/framework/callbacks/test_iteration_time_logger.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from torchtnt.framework.state import EntryPoint, PhaseState, State
2828
from torchtnt.framework.train import _train_impl, train
2929
from torchtnt.utils.loggers.logger import MetricLogger
30+
from torchtnt.utils.timer import Timer
3031

3132

3233
class IterationTimeLoggerTest(unittest.TestCase):
@@ -164,3 +165,34 @@ def test_with_summary_writer(self) -> None:
164165
train(my_unit, dataloader, max_epochs=2, callbacks=[callback])
165166
# 2 epochs, 6 iterations each, logging every third step
166167
self.assertEqual(logger.add_scalar.call_count, 4)
168+
169+
def test_warmup_steps(self) -> None:
170+
logger = MagicMock(spec=MetricLogger)
171+
callback = IterationTimeLogger(logger=logger, warmup_steps=1)
172+
timer = Timer()
173+
timer.recorded_durations = {"train_iteration_time": [1, 2]}
174+
175+
# ensure that we don't log for the first step
176+
callback._log_step_metrics("train_iteration_time", timer, 1)
177+
logger.log.assert_not_called()
178+
179+
# second step should log
180+
callback._log_step_metrics("train_iteration_time", timer, 2)
181+
self.assertEqual(logger.log.call_count, 1)
182+
183+
def test_invalid_params(self) -> None:
184+
logger = MagicMock(spec=MetricLogger)
185+
with self.assertRaisesRegex(
186+
ValueError, "moving_avg_window must be at least 1, got 0"
187+
):
188+
IterationTimeLogger(logger=logger, moving_avg_window=0)
189+
190+
with self.assertRaisesRegex(
191+
ValueError, "log_every_n_steps must be at least 1, got -1"
192+
):
193+
IterationTimeLogger(logger=logger, log_every_n_steps=-1)
194+
195+
with self.assertRaisesRegex(
196+
ValueError, "warmup_steps must be at least 0, got -1"
197+
):
198+
IterationTimeLogger(logger=logger, warmup_steps=-1)

torchtnt/framework/callbacks/iteration_time_logger.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,37 @@ class IterationTimeLogger(Callback):
2727
Args:
2828
logger: Either a :class:`torchtnt.loggers.tensorboard.TensorBoardLogger`
2929
or a :class:`torch.utils.tensorboard.SummaryWriter` instance.
30-
moving_avg_window: an optional int to control the moving average window
31-
log_every_n_steps: an optional int to control the log frequency
30+
moving_avg_window: an int to control the moving average window. Default is 1.
31+
log_every_n_steps: an int to control the log frequency. Default is 1.
32+
warmup_steps: an int to control the number of warmup steps. Logging starts only after the warmup steps have been completed. Default is 0.
3233
"""
3334

3435
def __init__(
3536
self,
3637
logger: Union[MetricLogger, SummaryWriter],
38+
*,
3739
moving_avg_window: int = 1,
3840
log_every_n_steps: int = 1,
41+
warmup_steps: int = 0,
3942
) -> None:
4043
self._logger = logger
44+
45+
if moving_avg_window < 1:
46+
raise ValueError(
47+
f"moving_avg_window must be at least 1, got {moving_avg_window}"
48+
)
4149
self.moving_avg_window = moving_avg_window
50+
51+
if log_every_n_steps < 1:
52+
raise ValueError(
53+
f"log_every_n_steps must be at least 1, got {log_every_n_steps}"
54+
)
4255
self.log_every_n_steps = log_every_n_steps
4356

57+
if warmup_steps < 0:
58+
raise ValueError(f"warmup_steps must be at least 0, got {warmup_steps}")
59+
self.warmup_steps = warmup_steps
60+
4461
@rank_zero_fn
4562
def _log_step_metrics(
4663
self,
@@ -53,6 +70,9 @@ def _log_step_metrics(
5370
was configured.
5471
5572
"""
73+
if step_logging_for <= self.warmup_steps:
74+
return
75+
5676
if step_logging_for % self.log_every_n_steps != 0:
5777
return
5878

0 commit comments

Comments
 (0)