Skip to content

Commit 249eea3

Browse files
galrotem authored and facebook-github-bot committed
iteration time logger support for MetricLogger (#759)
Summary: Pull Request resolved: #759 Add support for generic MetricLogger in IterationTimeLogger Reviewed By: JKSenthil Differential Revision: D55333343 fbshipit-source-id: 55b25d13d2d87e5037e3bb44bbcd0cceb466c3ac
1 parent c2dcee9 commit 249eea3

File tree

2 files changed

+53
-45
lines changed

2 files changed

+53
-45
lines changed

tests/framework/callbacks/test_iteration_time_logger.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,12 @@
2121

2222
from torchtnt.framework.state import State
2323
from torchtnt.framework.train import train
24-
from torchtnt.utils.loggers import TensorBoardLogger
24+
from torchtnt.utils.loggers.logger import MetricLogger
2525

2626

2727
class IterationTimeLoggerTest(unittest.TestCase):
2828
def test_iteration_time_logger_test_on_train_step_end(self) -> None:
29-
logger = MagicMock(spec=TensorBoardLogger)
30-
logger.writer = MagicMock(spec=SummaryWriter)
29+
logger = MagicMock(spec=MetricLogger)
3130
state = MagicMock(spec=State)
3231

3332
# Test that the recorded times are tracked separately and that we properly
@@ -64,7 +63,7 @@ def test_iteration_time_logger_test_on_train_step_end(self) -> None:
6463
callback.on_eval_step_end(state, eval_unit)
6564
callback.on_predict_step_end(state, predict_unit)
6665

67-
logger.writer.add_scalar.assert_has_calls(
66+
logger.log.assert_has_calls(
6867
[
6968
call(
7069
"Train Iteration Time (seconds)",
@@ -85,12 +84,26 @@ def test_with_train_epoch(self) -> None:
8584
"""
8685

8786
my_unit = DummyTrainUnit(input_dim=2)
88-
logger = MagicMock(spec=TensorBoardLogger)
89-
logger.writer = MagicMock(spec=SummaryWriter)
87+
logger = MagicMock(spec=MetricLogger)
9088
callback = IterationTimeLogger(logger, moving_avg_window=1, log_every_n_steps=3)
9189
dataloader = generate_random_dataloader(
9290
num_samples=12, input_dim=2, batch_size=2
9391
)
9492
train(my_unit, dataloader, max_epochs=2, callbacks=[callback])
9593
# 2 epochs, 6 iterations each, logging every third step
96-
self.assertEqual(logger.writer.add_scalar.call_count, 4)
94+
self.assertEqual(logger.log.call_count, 4)
95+
96+
def test_with_summary_writer(self) -> None:
97+
"""
98+
Test IterationTimeLogger callback with train entry point and SummaryWriter
99+
"""
100+
101+
my_unit = DummyTrainUnit(input_dim=2)
102+
logger = MagicMock(spec=SummaryWriter)
103+
callback = IterationTimeLogger(logger, moving_avg_window=1, log_every_n_steps=3)
104+
dataloader = generate_random_dataloader(
105+
num_samples=12, input_dim=2, batch_size=2
106+
)
107+
train(my_unit, dataloader, max_epochs=2, callbacks=[callback])
108+
# 2 epochs, 6 iterations each, logging every third step
109+
self.assertEqual(logger.add_scalar.call_count, 4)

torchtnt/framework/callbacks/iteration_time_logger.py

Lines changed: 33 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77
# pyre-strict
88

99

10-
from typing import Optional, Union
10+
from typing import cast, Optional, Union
1111

1212
from pyre_extensions import none_throws
1313
from torch.utils.tensorboard import SummaryWriter
1414

1515
from torchtnt.framework.callback import Callback
1616
from torchtnt.framework.state import State
1717
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
18-
from torchtnt.utils.distributed import get_global_rank
19-
from torchtnt.utils.loggers.tensorboard import TensorBoardLogger
18+
from torchtnt.utils.distributed import rank_zero_fn
19+
from torchtnt.utils.loggers.logger import MetricLogger
2020
from torchtnt.utils.timer import TimerProtocol
2121

2222

@@ -35,23 +35,17 @@ class IterationTimeLogger(Callback):
3535

3636
def __init__(
3737
self,
38-
logger: Union[TensorBoardLogger, SummaryWriter],
38+
logger: Union[MetricLogger, SummaryWriter],
3939
moving_avg_window: int = 1,
4040
log_every_n_steps: int = 1,
4141
) -> None:
42-
if isinstance(logger, TensorBoardLogger):
43-
logger = logger.writer
44-
45-
if get_global_rank() == 0: # only write from the main rank
46-
self._writer = none_throws(
47-
logger, "TensorBoardLogger.writer should not be None"
48-
)
42+
self._logger = logger
4943
self.moving_avg_window = moving_avg_window
5044
self.log_every_n_steps = log_every_n_steps
5145

46+
@rank_zero_fn
5247
def _log_step_metrics(
5348
self,
54-
writer: SummaryWriter,
5549
metric_label: str,
5650
iteration_timer: TimerProtocol,
5751
step_logging_for: int,
@@ -75,38 +69,39 @@ def _log_step_metrics(
7569
return
7670

7771
last_n_values = time_list[-self.moving_avg_window :]
78-
writer.add_scalar(
79-
human_metric_names[metric_label],
80-
sum(last_n_values) / len(last_n_values),
81-
step_logging_for,
82-
)
72+
if isinstance(self._logger, SummaryWriter):
73+
self._logger.add_scalar(
74+
human_metric_names[metric_label],
75+
sum(last_n_values) / len(last_n_values),
76+
step_logging_for,
77+
)
78+
else:
79+
cast(MetricLogger, self._logger).log(
80+
human_metric_names[metric_label],
81+
sum(last_n_values) / len(last_n_values),
82+
step_logging_for,
83+
)
8384

8485
def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
8586
timer = none_throws(state.train_state).iteration_timer
86-
if writer := self._writer:
87-
self._log_step_metrics(
88-
writer,
89-
"train_iteration_time",
90-
timer,
91-
unit.train_progress.num_steps_completed,
92-
)
87+
self._log_step_metrics(
88+
"train_iteration_time",
89+
timer,
90+
unit.train_progress.num_steps_completed,
91+
)
9392

9493
def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
9594
timer = none_throws(state.eval_state).iteration_timer
96-
if writer := self._writer:
97-
self._log_step_metrics(
98-
writer,
99-
"eval_iteration_time",
100-
timer,
101-
unit.eval_progress.num_steps_completed,
102-
)
95+
self._log_step_metrics(
96+
"eval_iteration_time",
97+
timer,
98+
unit.eval_progress.num_steps_completed,
99+
)
103100

104101
def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
105102
timer = none_throws(state.predict_state).iteration_timer
106-
if writer := self._writer:
107-
self._log_step_metrics(
108-
writer,
109-
"predict_iteration_time",
110-
timer,
111-
unit.predict_progress.num_steps_completed,
112-
)
103+
self._log_step_metrics(
104+
"predict_iteration_time",
105+
timer,
106+
unit.predict_progress.num_steps_completed,
107+
)

0 commit comments

Comments (0)