32
32
class ThroughputLoggerTest (unittest .TestCase ):
33
33
def test_maybe_log_for_step (self ) -> None :
34
34
logger = MagicMock (spec = MetricLogger )
35
- throughput_logger = ThroughputLogger (logger , {"Batches" : 1 , "Items" : 32 }, 1 )
35
+ throughput_logger = ThroughputLogger (logger , {"Batches" : 1 , "Items" : 32 })
36
36
phase_state = PhaseState (dataloader = [])
37
37
phase_state .iteration_timer .recorded_durations = {
38
38
"data_wait_time" : [1 , 4 ],
@@ -75,7 +75,7 @@ def test_maybe_log_for_step(self) -> None:
75
75
76
76
def test_maybe_log_for_step_early_return (self ) -> None :
77
77
logger = MagicMock (spec = MetricLogger )
78
- throughput_logger = ThroughputLogger (logger , {"Batches" : 1 }, 1 )
78
+ throughput_logger = ThroughputLogger (logger , {"Batches" : 1 })
79
79
phase_state = PhaseState (dataloader = [])
80
80
recorded_durations_dict = {
81
81
"data_wait_time" : [0.0 , 4.0 ],
@@ -101,7 +101,9 @@ def test_maybe_log_for_step_early_return(self) -> None:
101
101
102
102
# step_logging_for % log_every_n_steps != 0
103
103
recorded_durations_dict ["data_wait_time" ] = [1.0 , 2.0 ]
104
- throughput_logger = ThroughputLogger (logger , {"Batches" : 1 }, 2 )
104
+ throughput_logger = ThroughputLogger (
105
+ logger , {"Batches" : 1 }, log_every_n_steps = 2
106
+ )
105
107
throughput_logger ._maybe_log_for_step (state , step_logging_for = 1 )
106
108
logger .log .assert_not_called ()
107
109
@@ -330,17 +332,40 @@ def test_epoch_logging_time(self) -> None:
330
332
any_order = True ,
331
333
)
332
334
335
+ def test_warmup_steps (self ) -> None :
336
+ logger = MagicMock (spec = MetricLogger )
337
+ throughput_logger = ThroughputLogger (
338
+ logger , {"Batches" : 1 , "Items" : 32 }, warmup_steps = 1
339
+ )
340
+ phase_state = PhaseState (dataloader = [])
341
+ phase_state .iteration_timer .recorded_durations = {
342
+ "data_wait_time" : [1 , 4 ],
343
+ "train_iteration_time" : [3 ],
344
+ }
345
+ state = State (entry_point = EntryPoint .TRAIN , train_state = phase_state )
346
+
347
+ throughput_logger ._maybe_log_for_step (state , 1 )
348
+ logger .log .assert_not_called ()
349
+
350
+ throughput_logger ._maybe_log_for_step (state , 2 )
351
+ self .assertEqual (logger .log .call_count , 2 )
352
+
333
353
def test_input_validation (self ) -> None :
334
354
logger = MagicMock (spec = MetricLogger )
335
355
with self .assertRaisesRegex (ValueError , "throughput_per_batch cannot be empty" ):
336
- ThroughputLogger (logger , {}, 1 )
356
+ ThroughputLogger (logger , {})
337
357
338
358
with self .assertRaisesRegex (
339
359
ValueError , "throughput_per_batch item Batches must be at least 1, got -1"
340
360
):
341
- ThroughputLogger (logger , {"Queries" : 8 , "Batches" : - 1 }, 1 )
361
+ ThroughputLogger (logger , {"Queries" : 8 , "Batches" : - 1 })
342
362
343
363
with self .assertRaisesRegex (
344
364
ValueError , "log_every_n_steps must be at least 1, got 0"
345
365
):
346
- ThroughputLogger (logger , {"Batches" : 1 }, 0 )
366
+ ThroughputLogger (logger , {"Batches" : 1 }, log_every_n_steps = 0 )
367
+
368
+ with self .assertRaisesRegex (
369
+ ValueError , "warmup_steps must be at least 0, got -1"
370
+ ):
371
+ ThroughputLogger (logger , {"Batches" : 1 }, warmup_steps = - 1 )
0 commit comments