Skip to content

Commit e3ffa1f

Browse files
galrotemfacebook-github-bot
authored andcommitted
throughput logger
Summary: Introduce throughput logger. Internal # Context The stack adds a throughput logger that can be used to log generic throughput per second, based on user config. This diff will add the throughput logger including logging per step. The next diff will add throughput on an epoch granularity. # This diff Adds throughput logger: 1. It uses the already collected iteration time and data wait time timers to get the step time. 2. It's slightly confusing but when `on_train_step_end` is called, the iteration time timer hasn't been populated yet, while the data wait time timer has been populated, hence there's a difference between the two when we are logging for (step-1). On the `on_train_end` both lists are fully populated so we can just use the last element safely. Reviewed By: JKSenthil Differential Revision: D56496451 fbshipit-source-id: e6b119b1a42264d3e764da86e853deb03bd1cf82
1 parent 7737e13 commit e3ffa1f

File tree

4 files changed

+405
-1
lines changed

4 files changed

+405
-1
lines changed

docs/source/framework/callbacks.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ We offer several pre-written callbacks which are ready to be used out of the box
2222
BaseCSVWriter
2323
EarlyStopping
2424
GarbageCollector
25+
IterationTimeLogger
2526
Lambda
2627
LearningRateMonitor
2728
MemorySnapshot
@@ -33,7 +34,7 @@ We offer several pre-written callbacks which are ready to be used out of the box
3334
TensorBoardParameterMonitor
3435
TimeLimitInterrupter
3536
TimeWaitForBatchLogger
36-
IterationTimeLogger
37+
ThroughputLogger
3738
TorchSnapshotSaver
3839
TQDMProgressBar
3940
TrainProgressMonitor
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import unittest
11+
from unittest.mock import ANY, call, MagicMock
12+
13+
import torch
14+
from pyre_extensions import none_throws
15+
16+
from torchtnt.framework._callback_handler import CallbackHandler
17+
from torchtnt.framework._test_utils import (
18+
DummyAutoUnit,
19+
DummyPredictUnit,
20+
generate_random_dataloader,
21+
)
22+
from torchtnt.framework.callbacks.throughput_logger import ThroughputLogger
23+
from torchtnt.framework.predict import predict
24+
25+
from torchtnt.framework.state import EntryPoint, PhaseState, State
26+
from torchtnt.framework.train import _train_impl
27+
from torchtnt.utils.loggers.logger import MetricLogger
28+
29+
30+
class ThroughputLoggerTest(unittest.TestCase):
31+
def test_maybe_log_for_step(self) -> None:
32+
logger = MagicMock(spec=MetricLogger)
33+
throughput_logger = ThroughputLogger(logger, {"Batches": 1, "Items": 32}, 1)
34+
phase_state = PhaseState(dataloader=[])
35+
phase_state.iteration_timer.recorded_durations = {
36+
"data_wait_time": [1, 4],
37+
"train_iteration_time": [3],
38+
}
39+
state = State(entry_point=EntryPoint.TRAIN, train_state=phase_state)
40+
throughput_logger._maybe_log_for_step(state, 1)
41+
logger.log.assert_has_calls(
42+
[
43+
call(
44+
"Train: Batches per second (step granularity)",
45+
0.25, # 1/(1+3)
46+
1,
47+
),
48+
call(
49+
"Train: Items per second (step granularity)",
50+
8, # 32/(1+3)
51+
1,
52+
),
53+
],
54+
any_order=True,
55+
)
56+
logger.log.reset_mock()
57+
phase_state.iteration_timer.recorded_durations["train_iteration_time"].append(4)
58+
throughput_logger._maybe_log_for_step(state, 2, is_step_end_hook=False)
59+
logger.log.assert_has_calls(
60+
[
61+
call(
62+
"Train: Batches per second (step granularity)",
63+
0.125, # 1/(4+4)
64+
2,
65+
),
66+
call(
67+
"Train: Items per second (step granularity)",
68+
4, # 32/(4+4)
69+
2,
70+
),
71+
]
72+
)
73+
74+
def test_maybe_log_for_step_early_return(self) -> None:
75+
logger = MagicMock(spec=MetricLogger)
76+
throughput_logger = ThroughputLogger(logger, {"Batches": 1}, 1)
77+
phase_state = PhaseState(dataloader=[])
78+
recorded_durations_dict = {
79+
"data_wait_time": [0.0, 4.0],
80+
"train_iteration_time": [0.0],
81+
}
82+
# total_time <= 0
83+
phase_state.iteration_timer.recorded_durations = recorded_durations_dict
84+
state = State(entry_point=EntryPoint.TRAIN, train_state=phase_state)
85+
throughput_logger._maybe_log_for_step(state, step_logging_for=1)
86+
logger.log.assert_not_called()
87+
88+
# empty iteration_time_list
89+
recorded_durations_dict["data_wait_time"] = [1.0, 2.0]
90+
recorded_durations_dict["train_iteration_time"] = []
91+
throughput_logger._maybe_log_for_step(state, step_logging_for=1)
92+
logger.log.assert_not_called()
93+
94+
# small data_wait_time list
95+
recorded_durations_dict["data_wait_time"] = [1.0]
96+
recorded_durations_dict["train_iteration_time"] = [1.0]
97+
throughput_logger._maybe_log_for_step(state, step_logging_for=1)
98+
logger.log.assert_not_called()
99+
100+
# step_logging_for % log_every_n_steps != 0
101+
recorded_durations_dict["data_wait_time"] = [1.0, 2.0]
102+
throughput_logger = ThroughputLogger(logger, {"Batches": 1}, 2)
103+
throughput_logger._maybe_log_for_step(state, step_logging_for=1)
104+
logger.log.assert_not_called()
105+
106+
def test_with_comparing_time(self) -> None:
107+
logger = MagicMock(spec=MetricLogger)
108+
dataloader = generate_random_dataloader(
109+
num_samples=8, input_dim=2, batch_size=2
110+
)
111+
state = State(
112+
entry_point=EntryPoint.FIT,
113+
train_state=PhaseState(
114+
dataloader=dataloader,
115+
max_epochs=2,
116+
max_steps_per_epoch=2,
117+
),
118+
eval_state=PhaseState(
119+
dataloader=dataloader,
120+
max_steps_per_epoch=2,
121+
evaluate_every_n_epochs=2,
122+
),
123+
)
124+
125+
# we want to be able to compare the logging value to the state, so we need to create state manually and
126+
# call _train_impl. This would have been similar to calling fit() and getting the state as a ret value
127+
_train_impl(
128+
state,
129+
DummyAutoUnit(module=torch.nn.Linear(2, 2)),
130+
CallbackHandler(
131+
[
132+
ThroughputLogger(
133+
logger=logger,
134+
throughput_per_batch={"Batches": 1, "Queries": 8},
135+
log_every_n_steps=1,
136+
)
137+
],
138+
),
139+
)
140+
141+
train_iteration_times = none_throws(
142+
state.train_state
143+
).iteration_timer.recorded_durations["train_iteration_time"]
144+
train_twfb_times = none_throws(
145+
state.train_state
146+
).iteration_timer.recorded_durations["data_wait_time"]
147+
eval_iteration_times = none_throws(
148+
state.eval_state
149+
).iteration_timer.recorded_durations["eval_iteration_time"]
150+
eval_twfb_times = none_throws(
151+
state.eval_state
152+
).iteration_timer.recorded_durations["data_wait_time"]
153+
154+
self.assertEqual(len(train_iteration_times), 4)
155+
self.assertEqual(len(train_twfb_times), 4)
156+
self.assertEqual(len(eval_iteration_times), 2)
157+
self.assertEqual(len(eval_twfb_times), 2)
158+
159+
train_step_times = [
160+
train_iteration_times[i] + train_twfb_times[i] for i in range(4)
161+
]
162+
eval_step_times = [
163+
eval_iteration_times[i] + eval_twfb_times[i] for i in range(2)
164+
]
165+
self.assertEqual(
166+
logger.log.call_count, 12
167+
) # 8 train (2epochs x 2steps x 2items), 4 eval (1x2x2)
168+
train_batches_step_logs = [
169+
call(
170+
"Train: Batches per second (step granularity)",
171+
1 / (train_step_times[i]),
172+
i + 1,
173+
)
174+
for i in range(4)
175+
]
176+
train_queries_step_logs = [
177+
call(
178+
"Train: Queries per second (step granularity)",
179+
8 / (train_step_times[i]),
180+
i + 1,
181+
)
182+
for i in range(4)
183+
]
184+
eval_batches_step_logs = [
185+
call(
186+
"Eval: Batches per second (step granularity)",
187+
1 / (eval_step_times[i]),
188+
i + 1,
189+
)
190+
for i in range(2)
191+
]
192+
eval_queries_step_logs = [
193+
call(
194+
"Eval: Queries per second (step granularity)",
195+
8 / (eval_step_times[i]),
196+
i + 1,
197+
)
198+
for i in range(2)
199+
]
200+
logger.log.assert_has_calls(
201+
train_batches_step_logs
202+
+ train_queries_step_logs
203+
+ eval_batches_step_logs
204+
+ eval_queries_step_logs,
205+
any_order=True,
206+
)
207+
208+
def test_with_predict(self) -> None:
209+
logger = MagicMock(spec=MetricLogger)
210+
predict(
211+
DummyPredictUnit(input_dim=2),
212+
generate_random_dataloader(num_samples=8, input_dim=2, batch_size=2),
213+
max_steps_per_epoch=1,
214+
callbacks=[
215+
ThroughputLogger(
216+
logger=logger,
217+
throughput_per_batch={"Batches": 1},
218+
log_every_n_steps=1,
219+
)
220+
],
221+
)
222+
logger.log.assert_has_calls(
223+
[
224+
call(
225+
"Predict: Batches per second (step granularity)",
226+
ANY,
227+
1,
228+
)
229+
],
230+
)
231+
232+
def test_input_validation(self) -> None:
233+
logger = MagicMock(spec=MetricLogger)
234+
with self.assertRaisesRegex(ValueError, "throughput_per_batch cannot be empty"):
235+
ThroughputLogger(logger, {}, 1)
236+
237+
with self.assertRaisesRegex(
238+
ValueError, "throughput_per_batch item Batches must be at least 1, got -1"
239+
):
240+
ThroughputLogger(logger, {"Queries": 8, "Batches": -1}, 1)
241+
242+
with self.assertRaisesRegex(
243+
ValueError, "log_every_n_steps must be at least 1, got 0"
244+
):
245+
ThroughputLogger(logger, {"Batches": 1}, 0)

torchtnt/framework/callbacks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .slow_rank_detector import SlowRankDetector
2222
from .system_resources_monitor import SystemResourcesMonitor
2323
from .tensorboard_parameter_monitor import TensorBoardParameterMonitor
24+
from .throughput_logger import ThroughputLogger
2425
from .time_limit_interrupter import TimeLimitInterrupter
2526
from .time_wait_for_batch_logger import TimeWaitForBatchLogger
2627
from .torch_compile import TorchCompile
@@ -43,6 +44,7 @@
4344
"SlowRankDetector",
4445
"SystemResourcesMonitor",
4546
"TensorBoardParameterMonitor",
47+
"ThroughputLogger",
4648
"TimeLimitInterrupter",
4749
"TimeWaitForBatchLogger",
4850
"TorchCompile",

0 commit comments

Comments
 (0)