Skip to content

Commit c7095bf

Browse files
galrotem authored and facebook-github-bot committed
fix iteration time logger logging steps (#786)
Summary: Pull Request resolved: #786
Make sure all steps are logged with the right value.
Reviewed By: anshulverma
Differential Revision: D56199868
fbshipit-source-id: 69c6088e75c9af79d91547b094e5d8a6f7c3cfaf
1 parent 6de95a5 commit c7095bf

File tree

2 files changed

+89
-4
lines changed

2 files changed

+89
-4
lines changed

tests/framework/callbacks/test_iteration_time_logger.py

Lines changed: 61 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,17 +10,22 @@
1010
import unittest
1111
from unittest.mock import call, MagicMock
1212

13+
import torch
14+
from pyre_extensions import none_throws
15+
1316
from torch.utils.tensorboard import SummaryWriter
17+
from torchtnt.framework._callback_handler import CallbackHandler
1418
from torchtnt.framework._test_utils import (
19+
DummyAutoUnit,
1520
DummyEvalUnit,
1621
DummyPredictUnit,
1722
DummyTrainUnit,
1823
generate_random_dataloader,
1924
)
2025
from torchtnt.framework.callbacks.iteration_time_logger import IterationTimeLogger
2126

22-
from torchtnt.framework.state import State
23-
from torchtnt.framework.train import train
27+
from torchtnt.framework.state import EntryPoint, PhaseState, State
28+
from torchtnt.framework.train import _train_impl, train
2429
from torchtnt.utils.loggers.logger import MetricLogger
2530

2631

@@ -68,12 +73,12 @@ def test_iteration_time_logger_test_on_train_step_end(self) -> None:
6873
call(
6974
"Train Iteration Time (seconds)",
7075
6.0, # the average of the last 4 numbers is 6
71-
2, # after incrementing twice, step should be 2
76+
1, # at on_train_step_end we report for step-1, we incremented twice so value should be 1
7277
),
7378
call(
7479
"Prediction Iteration Time (seconds)",
7580
16.0, # the average of the last 4 numbers is 16
76-
2, # after incrementing twice, step should be 2
81+
1, # at on_predict_step_end we report for step-1, we incremented twice so value should be 1
7782
),
7883
]
7984
)
@@ -93,6 +98,58 @@ def test_with_train_epoch(self) -> None:
9398
# 2 epochs, 6 iterations each, logging every third step
9499
self.assertEqual(logger.log.call_count, 4)
95100

101+
def test_comparing_step_logging_time(self) -> None:
102+
"""
103+
Test IterationTimeLogger callback and compare reported time to collected time
104+
"""
105+
106+
my_auto_unit = DummyAutoUnit(module=torch.nn.Linear(2, 2))
107+
logger = MagicMock(spec=MetricLogger)
108+
iteration_time_logger = IterationTimeLogger(
109+
logger, moving_avg_window=1, log_every_n_steps=1
110+
)
111+
dataloader = generate_random_dataloader(
112+
num_samples=8, input_dim=2, batch_size=2
113+
)
114+
state = State(
115+
entry_point=EntryPoint.FIT,
116+
train_state=PhaseState(
117+
dataloader=dataloader,
118+
max_epochs=2,
119+
max_steps_per_epoch=2,
120+
),
121+
eval_state=PhaseState(
122+
dataloader=dataloader,
123+
max_steps_per_epoch=2,
124+
evaluate_every_n_epochs=1,
125+
),
126+
)
127+
128+
# we want to be able to compare the logging value to the state, so we need to create state manually and
129+
# call _train_impl. This would have been similar to calling fit() and getting the state as a ret value
130+
131+
_train_impl(state, my_auto_unit, CallbackHandler([iteration_time_logger]))
132+
train_iteration_timer = none_throws(
133+
state.train_state
134+
).iteration_timer.recorded_durations["train_iteration_time"]
135+
eval_iteration_timer = none_throws(
136+
state.eval_state
137+
).iteration_timer.recorded_durations["eval_iteration_time"]
138+
139+
expected_training_iteration_time_calls = [
140+
call("Train Iteration Time (seconds)", train_iteration_timer[i], i + 1)
141+
for i in range(4)
142+
]
143+
expected_eval_iteration_time_calls = [
144+
call("Eval Iteration Time (seconds)", eval_iteration_timer[i], i + 1)
145+
for i in range(4)
146+
]
147+
148+
logger.log.assert_has_calls(
149+
expected_training_iteration_time_calls + expected_eval_iteration_time_calls,
150+
any_order=True,
151+
)
152+
96153
def test_with_summary_writer(self) -> None:
97154
"""
98155
Test IterationTimeLogger callback with train entry point and SummaryWriter

torchtnt/framework/callbacks/iteration_time_logger.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,15 @@ def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
8787
self._log_step_metrics(
8888
"train_iteration_time",
8989
timer,
90+
# on_train_step_end happens after the num steps is incremented, but before the timer list is populated,
91+
# so it logs for step-1
92+
unit.train_progress.num_steps_completed - 1,
93+
)
94+
95+
def on_train_end(self, state: State, unit: TTrainUnit) -> None:
96+
self._log_step_metrics(
97+
"train_iteration_time",
98+
none_throws(state.train_state).iteration_timer,
9099
unit.train_progress.num_steps_completed,
91100
)
92101

@@ -95,10 +104,29 @@ def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
95104
self._log_step_metrics(
96105
"eval_iteration_time",
97106
timer,
107+
# on_eval_step_end happens after the num steps is incremented, but before the timer list is populated,
108+
# so it logs for step-1
109+
unit.eval_progress.num_steps_completed - 1,
110+
)
111+
112+
def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
113+
self._log_step_metrics(
114+
"eval_iteration_time",
115+
none_throws(state.eval_state).iteration_timer,
98116
unit.eval_progress.num_steps_completed,
99117
)
100118

101119
def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
120+
timer = none_throws(state.predict_state).iteration_timer
121+
self._log_step_metrics(
122+
"predict_iteration_time",
123+
timer,
124+
# on_predict_step_end happens after the num steps is incremented, but before the timer list is populated,
125+
# so it logs for step-1
126+
unit.predict_progress.num_steps_completed - 1,
127+
)
128+
129+
def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
102130
timer = none_throws(state.predict_state).iteration_timer
103131
self._log_step_metrics(
104132
"predict_iteration_time",

0 commit comments

Comments
 (0)