10
10
import unittest
11
11
from unittest .mock import call , MagicMock
12
12
13
+ import torch
14
+ from pyre_extensions import none_throws
15
+
13
16
from torch .utils .tensorboard import SummaryWriter
17
+ from torchtnt .framework ._callback_handler import CallbackHandler
14
18
from torchtnt .framework ._test_utils import (
19
+ DummyAutoUnit ,
15
20
DummyEvalUnit ,
16
21
DummyPredictUnit ,
17
22
DummyTrainUnit ,
18
23
generate_random_dataloader ,
19
24
)
20
25
from torchtnt .framework .callbacks .iteration_time_logger import IterationTimeLogger
21
26
22
- from torchtnt .framework .state import State
23
- from torchtnt .framework .train import train
27
+ from torchtnt .framework .state import EntryPoint , PhaseState , State
28
+ from torchtnt .framework .train import _train_impl , train
24
29
from torchtnt .utils .loggers .logger import MetricLogger
25
30
26
31
@@ -68,12 +73,12 @@ def test_iteration_time_logger_test_on_train_step_end(self) -> None:
68
73
call (
69
74
"Train Iteration Time (seconds)" ,
70
75
6.0 , # the average of the last 4 numbers is 6
71
- 2 , # after incrementing twice, step should be 2
76
+ 1 , # at on_train_step_end we report for step-1, we incremented twice so value should be 1
72
77
),
73
78
call (
74
79
"Prediction Iteration Time (seconds)" ,
75
80
16.0 , # the average of the last 4 numbers is 16
76
- 2 , # after incrementing twice, step should be 2
81
+ 1 , # at on_predict_step_end we report for step-1, we incremented twice so value should be 1
77
82
),
78
83
]
79
84
)
@@ -93,6 +98,58 @@ def test_with_train_epoch(self) -> None:
93
98
# 2 epochs, 6 iterations each, logging every third step
94
99
self .assertEqual (logger .log .call_count , 4 )
95
100
101
+ def test_comparing_step_logging_time (self ) -> None :
102
+ """
103
+ Test IterationTimeLogger callback and compare reported time to collected time
104
+ """
105
+
106
+ my_auto_unit = DummyAutoUnit (module = torch .nn .Linear (2 , 2 ))
107
+ logger = MagicMock (spec = MetricLogger )
108
+ iteration_time_logger = IterationTimeLogger (
109
+ logger , moving_avg_window = 1 , log_every_n_steps = 1
110
+ )
111
+ dataloader = generate_random_dataloader (
112
+ num_samples = 8 , input_dim = 2 , batch_size = 2
113
+ )
114
+ state = State (
115
+ entry_point = EntryPoint .FIT ,
116
+ train_state = PhaseState (
117
+ dataloader = dataloader ,
118
+ max_epochs = 2 ,
119
+ max_steps_per_epoch = 2 ,
120
+ ),
121
+ eval_state = PhaseState (
122
+ dataloader = dataloader ,
123
+ max_steps_per_epoch = 2 ,
124
+ evaluate_every_n_epochs = 1 ,
125
+ ),
126
+ )
127
+
128
+ # we want to be able to compare the logging value to the state, so we need to create state manually and
129
+ # call _train_impl. This would have been similar to calling fit() and getting the state as a ret value
130
+
131
+ _train_impl (state , my_auto_unit , CallbackHandler ([iteration_time_logger ]))
132
+ train_iteration_timer = none_throws (
133
+ state .train_state
134
+ ).iteration_timer .recorded_durations ["train_iteration_time" ]
135
+ eval_iteration_timer = none_throws (
136
+ state .eval_state
137
+ ).iteration_timer .recorded_durations ["eval_iteration_time" ]
138
+
139
+ expected_training_iteration_time_calls = [
140
+ call ("Train Iteration Time (seconds)" , train_iteration_timer [i ], i + 1 )
141
+ for i in range (4 )
142
+ ]
143
+ expected_eval_iteration_time_calls = [
144
+ call ("Eval Iteration Time (seconds)" , eval_iteration_timer [i ], i + 1 )
145
+ for i in range (4 )
146
+ ]
147
+
148
+ logger .log .assert_has_calls (
149
+ expected_training_iteration_time_calls + expected_eval_iteration_time_calls ,
150
+ any_order = True ,
151
+ )
152
+
96
153
def test_with_summary_writer (self ) -> None :
97
154
"""
98
155
Test IterationTimeLogger callback with train entry point and SummaryWriter
0 commit comments