
Commit 077c1fe

galrotem authored and facebook-github-bot committed

add warmup steps to throughput logger (#840)

Summary:
Pull Request resolved: #840

Adding warmup steps to different performance loggers.

Reviewed By: diego-urgell

Differential Revision: D57596034

fbshipit-source-id: ceeb60ae08b7bae33f69525816407f36d0510bfc
1 parent ba310cf · commit 077c1fe

File tree: 2 files changed, +44 -7 lines changed


tests/framework/callbacks/test_throughput_logger.py

Lines changed: 31 additions & 6 deletions
@@ -32,7 +32,7 @@
 class ThroughputLoggerTest(unittest.TestCase):
     def test_maybe_log_for_step(self) -> None:
         logger = MagicMock(spec=MetricLogger)
-        throughput_logger = ThroughputLogger(logger, {"Batches": 1, "Items": 32}, 1)
+        throughput_logger = ThroughputLogger(logger, {"Batches": 1, "Items": 32})
         phase_state = PhaseState(dataloader=[])
         phase_state.iteration_timer.recorded_durations = {
             "data_wait_time": [1, 4],
@@ -75,7 +75,7 @@ def test_maybe_log_for_step(self) -> None:

     def test_maybe_log_for_step_early_return(self) -> None:
         logger = MagicMock(spec=MetricLogger)
-        throughput_logger = ThroughputLogger(logger, {"Batches": 1}, 1)
+        throughput_logger = ThroughputLogger(logger, {"Batches": 1})
         phase_state = PhaseState(dataloader=[])
         recorded_durations_dict = {
             "data_wait_time": [0.0, 4.0],
@@ -101,7 +101,9 @@ def test_maybe_log_for_step_early_return(self) -> None:

         # step_logging_for % log_every_n_steps != 0
         recorded_durations_dict["data_wait_time"] = [1.0, 2.0]
-        throughput_logger = ThroughputLogger(logger, {"Batches": 1}, 2)
+        throughput_logger = ThroughputLogger(
+            logger, {"Batches": 1}, log_every_n_steps=2
+        )
         throughput_logger._maybe_log_for_step(state, step_logging_for=1)
         logger.log.assert_not_called()

@@ -330,17 +332,40 @@ def test_epoch_logging_time(self) -> None:
             any_order=True,
         )

+    def test_warmup_steps(self) -> None:
+        logger = MagicMock(spec=MetricLogger)
+        throughput_logger = ThroughputLogger(
+            logger, {"Batches": 1, "Items": 32}, warmup_steps=1
+        )
+        phase_state = PhaseState(dataloader=[])
+        phase_state.iteration_timer.recorded_durations = {
+            "data_wait_time": [1, 4],
+            "train_iteration_time": [3],
+        }
+        state = State(entry_point=EntryPoint.TRAIN, train_state=phase_state)
+
+        throughput_logger._maybe_log_for_step(state, 1)
+        logger.log.assert_not_called()
+
+        throughput_logger._maybe_log_for_step(state, 2)
+        self.assertEqual(logger.log.call_count, 2)
+
     def test_input_validation(self) -> None:
         logger = MagicMock(spec=MetricLogger)
         with self.assertRaisesRegex(ValueError, "throughput_per_batch cannot be empty"):
-            ThroughputLogger(logger, {}, 1)
+            ThroughputLogger(logger, {})

         with self.assertRaisesRegex(
             ValueError, "throughput_per_batch item Batches must be at least 1, got -1"
         ):
-            ThroughputLogger(logger, {"Queries": 8, "Batches": -1}, 1)
+            ThroughputLogger(logger, {"Queries": 8, "Batches": -1})

         with self.assertRaisesRegex(
             ValueError, "log_every_n_steps must be at least 1, got 0"
         ):
-            ThroughputLogger(logger, {"Batches": 1}, 0)
+            ThroughputLogger(logger, {"Batches": 1}, log_every_n_steps=0)
+
+        with self.assertRaisesRegex(
+            ValueError, "warmup_steps must be at least 0, got -1"
+        ):
+            ThroughputLogger(logger, {"Batches": 1}, warmup_steps=-1)
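
The new test_warmup_steps case pins down the intended behavior: with warmup_steps=1, step 1 produces no logs, and step 2 logs once per configured metric (two metrics here, hence call_count == 2). As a quick illustration, the following standalone sketch mirrors the step-gating logic this commit adds; it is not library code, and the should_log helper name is made up for the example.

    def should_log(step: int, *, warmup_steps: int = 0, log_every_n_steps: int = 1) -> bool:
        # Warmup window: suppress logging entirely for the first warmup_steps steps.
        if step <= warmup_steps:
            return False
        # Afterwards, log on the configured cadence.
        return step % log_every_n_steps == 0

    # Matches the assertions in test_warmup_steps (warmup_steps=1):
    assert not should_log(1, warmup_steps=1)
    assert should_log(2, warmup_steps=1)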

torchtnt/framework/callbacks/throughput_logger.py

Lines changed: 13 additions & 1 deletion
@@ -48,7 +48,8 @@ class ThroughputLogger(Callback):
             For instace, a user can pass in {Batches: 1, Queries: 32} which will visualize two charts -
             one for Batches per second and one for Queries per second.
             As an example, if each of your batches is of type: {data: torch.Size([16, 8, 8]), labels: torch.Size([16,1])}, then you could pass {Queries: 16}.
-        log_every_n_steps: an optional int to control the log frequency.
+        log_every_n_steps: an int to control the log frequency. Default is 1.
+        warmup_steps: an int to control the number of warmup steps. We will start logging only after the amount of warmup steps were completed. Default is 0.

     Note:
         The values reported are only for rank 0.
@@ -59,7 +60,9 @@ def __init__(
         self,
         logger: MetricLogger,
         throughput_per_batch: Mapping[str, int],
+        *,
         log_every_n_steps: int = 1,
+        warmup_steps: int = 0,
     ) -> None:
         self._logger = logger

@@ -80,6 +83,12 @@ def __init__(
         )

         self._log_every_n_steps = log_every_n_steps
+
+        if warmup_steps < 0:
+            raise ValueError(f"warmup_steps must be at least 0, got {warmup_steps}")
+
+        self._warmup_steps = warmup_steps
+
         self._epoch_start_times: Dict[ActivePhase, float] = {}
         self._steps_in_epoch: Dict[ActivePhase, int] = defaultdict(int)

@@ -154,6 +163,9 @@ def _maybe_log_for_step(
         *,
         is_step_end_hook: bool = True,
     ) -> None:
+        if step_logging_for <= self._warmup_steps:
+            return
+
         if step_logging_for % self._log_every_n_steps != 0:
             return
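
With these changes, log_every_n_steps and warmup_steps become keyword-only constructor arguments. A minimal usage sketch follows; the TensorBoardLogger import and the train() entry point are assumptions drawn from the wider torchtnt API, not from this commit.

    from torchtnt.framework.callbacks.throughput_logger import ThroughputLogger
    from torchtnt.utils.loggers import TensorBoardLogger  # assumed logger backend

    logger = TensorBoardLogger(path="/tmp/tb_logs")
    throughput_logger = ThroughputLogger(
        logger,
        throughput_per_batch={"Batches": 1, "Queries": 32},
        warmup_steps=100,      # skip the first 100 steps before logging throughput
        log_every_n_steps=10,  # after warmup, log on every 10th step
    )

    # The callback would then be passed to the training entry point, e.g.:
    # train(my_unit, dataloader, callbacks=[throughput_logger])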
