Skip to content

Commit 8150bcc

Browse files
diego-urgell and facebook-github-bot
authored and committed
Add more dataloader hooks to the Callback interface (#937)
Summary: Pull Request resolved: #937
Reviewed By: JKSenthil
Differential Revision: D64909630
fbshipit-source-id: b9fe4b8bf736505bf92e162fb6bbd96999988c46
1 parent 3bc2dfd commit 8150bcc

File tree

7 files changed

+234
-2
lines changed

7 files changed

+234
-2
lines changed

tests/framework/test_callback_handler.py

Lines changed: 75 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,19 @@ def on_train_start(self, state: State, unit: TTrainUnit) -> None:
4646
def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None:
4747
self.called_hooks.add("on_train_epoch_start")
4848

49+
def on_train_dataloader_iter_creation_start(
50+
self, state: State, unit: TTrainUnit
51+
) -> None:
52+
self.called_hooks.add("on_train_dataloader_iter_creation_start")
53+
54+
def on_train_dataloader_iter_creation_end(
55+
self, state: State, unit: TTrainUnit
56+
) -> None:
57+
self.called_hooks.add("on_train_dataloader_iter_creation_end")
58+
59+
def on_train_get_next_batch_start(self, state: State, unit: TTrainUnit) -> None:
60+
self.called_hooks.add("on_train_get_next_batch_start")
61+
4962
def on_train_get_next_batch_end(self, state: State, unit: TTrainUnit) -> None:
5063
self.called_hooks.add("on_train_get_next_batch_end")
5164

@@ -67,6 +80,19 @@ def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
6780
def on_eval_epoch_start(self, state: State, unit: TEvalUnit) -> None:
6881
self.called_hooks.add("on_eval_epoch_start")
6982

83+
def on_eval_dataloader_iter_creation_start(
84+
self, state: State, unit: TEvalUnit
85+
) -> None:
86+
self.called_hooks.add("on_eval_dataloader_iter_creation_start")
87+
88+
def on_eval_dataloader_iter_creation_end(
89+
self, state: State, unit: TEvalUnit
90+
) -> None:
91+
self.called_hooks.add("on_eval_dataloader_iter_creation_end")
92+
93+
def on_eval_get_next_batch_start(self, state: State, unit: TEvalUnit) -> None:
94+
self.called_hooks.add("on_eval_get_next_batch_start")
95+
7096
def on_eval_get_next_batch_end(self, state: State, unit: TEvalUnit) -> None:
7197
self.called_hooks.add("on_eval_get_next_batch_end")
7298

@@ -85,6 +111,19 @@ def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
85111
def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
86112
self.called_hooks.add("on_predict_start")
87113

114+
def on_predict_dataloader_iter_creation_start(
115+
self, state: State, unit: TPredictUnit
116+
) -> None:
117+
self.called_hooks.add("on_predict_dataloader_iter_creation_start")
118+
119+
def on_predict_dataloader_iter_creation_end(
120+
self, state: State, unit: TPredictUnit
121+
) -> None:
122+
self.called_hooks.add("on_predict_dataloader_iter_creation_end")
123+
124+
def on_predict_get_next_batch_start(self, state: State, unit: TPredictUnit) -> None:
125+
self.called_hooks.add("on_predict_get_next_batch_start")
126+
88127
def on_predict_epoch_start(self, state: State, unit: TPredictUnit) -> None:
89128
self.called_hooks.add("on_predict_epoch_start")
90129

@@ -129,6 +168,15 @@ def test_callback_handler(self) -> None:
129168
cb_handler.on_train_epoch_start(state, unit)
130169
self.assertIn("on_train_epoch_start", called_hooks)
131170

171+
cb_handler.on_train_dataloader_iter_creation_start(state, unit)
172+
self.assertIn("on_train_dataloader_iter_creation_start", called_hooks)
173+
174+
cb_handler.on_train_dataloader_iter_creation_end(state, unit)
175+
self.assertIn("on_train_dataloader_iter_creation_end", called_hooks)
176+
177+
cb_handler.on_train_get_next_batch_start(state, unit)
178+
self.assertIn("on_train_get_next_batch_start", called_hooks)
179+
132180
cb_handler.on_train_get_next_batch_end(state, unit)
133181
self.assertIn("on_train_get_next_batch_end", called_hooks)
134182

@@ -154,6 +202,15 @@ def test_callback_handler(self) -> None:
154202
cb_handler.on_eval_epoch_start(state, unit)
155203
self.assertIn("on_eval_epoch_start", called_hooks)
156204

205+
cb_handler.on_eval_dataloader_iter_creation_start(state, unit)
206+
self.assertIn("on_eval_dataloader_iter_creation_start", called_hooks)
207+
208+
cb_handler.on_eval_dataloader_iter_creation_end(state, unit)
209+
self.assertIn("on_eval_dataloader_iter_creation_end", called_hooks)
210+
211+
cb_handler.on_eval_get_next_batch_start(state, unit)
212+
self.assertIn("on_eval_get_next_batch_start", called_hooks)
213+
157214
cb_handler.on_eval_get_next_batch_end(state, unit)
158215
self.assertIn("on_eval_get_next_batch_end", called_hooks)
159216

@@ -179,6 +236,15 @@ def test_callback_handler(self) -> None:
179236
cb_handler.on_predict_epoch_start(state, unit)
180237
self.assertIn("on_predict_epoch_start", called_hooks)
181238

239+
cb_handler.on_predict_dataloader_iter_creation_start(state, unit)
240+
self.assertIn("on_predict_dataloader_iter_creation_start", called_hooks)
241+
242+
cb_handler.on_predict_dataloader_iter_creation_end(state, unit)
243+
self.assertIn("on_predict_dataloader_iter_creation_end", called_hooks)
244+
245+
cb_handler.on_predict_get_next_batch_start(state, unit)
246+
self.assertIn("on_predict_get_next_batch_start", called_hooks)
247+
182248
cb_handler.on_predict_get_next_batch_end(state, unit)
183249
self.assertIn("on_predict_get_next_batch_end", called_hooks)
184250

@@ -202,20 +268,29 @@ def test_get_implemented_callback_mapping(self) -> None:
202268
remaining_callback_hooks = (
203269
"on_train_start",
204270
"on_train_epoch_start",
271+
"on_train_dataloader_iter_creation_start",
272+
"on_train_dataloader_iter_creation_end",
273+
"on_train_get_next_batch_start",
205274
"on_train_get_next_batch_end",
206275
"on_train_step_start",
207276
"on_train_step_end",
208277
"on_train_epoch_end",
209278
"on_train_end",
210279
"on_eval_start",
211280
"on_eval_epoch_start",
281+
"on_eval_dataloader_iter_creation_start",
282+
"on_eval_dataloader_iter_creation_end",
283+
"on_eval_get_next_batch_start",
212284
"on_eval_get_next_batch_end",
213285
"on_eval_step_start",
214286
"on_eval_step_end",
215287
"on_eval_epoch_end",
216288
"on_eval_end",
217289
"on_predict_start",
218290
"on_predict_epoch_start",
291+
"on_predict_dataloader_iter_creation_start",
292+
"on_predict_dataloader_iter_creation_end",
293+
"on_predict_get_next_batch_start",
219294
"on_predict_get_next_batch_end",
220295
"on_predict_step_start",
221296
"on_predict_step_end",

torchtnt/framework/_callback_handler.py

Lines changed: 75 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -63,20 +63,29 @@ def _get_implemented_callback_mapping(
6363
"on_exception",
6464
"on_train_start",
6565
"on_train_epoch_start",
66+
"on_train_dataloader_iter_creation_start",
67+
"on_train_dataloader_iter_creation_end",
68+
"on_train_get_next_batch_start",
6669
"on_train_get_next_batch_end",
6770
"on_train_step_start",
6871
"on_train_step_end",
6972
"on_train_epoch_end",
7073
"on_train_end",
7174
"on_eval_start",
7275
"on_eval_epoch_start",
76+
"on_eval_dataloader_iter_creation_start",
77+
"on_eval_dataloader_iter_creation_end",
78+
"on_eval_get_next_batch_start",
7379
"on_eval_get_next_batch_end",
7480
"on_eval_step_start",
7581
"on_eval_step_end",
7682
"on_eval_epoch_end",
7783
"on_eval_end",
7884
"on_predict_start",
7985
"on_predict_epoch_start",
86+
"on_predict_dataloader_iter_creation_start",
87+
"on_predict_dataloader_iter_creation_end",
88+
"on_predict_get_next_batch_start",
8089
"on_predict_get_next_batch_end",
8190
"on_predict_step_start",
8291
"on_predict_step_end",
@@ -127,6 +136,28 @@ def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None:
127136
for cb in callbacks:
128137
cb.on_train_epoch_start(state, unit)
129138

139+
def on_train_dataloader_iter_creation_start(
140+
self, state: State, unit: TTrainUnit
141+
) -> None:
142+
fn_name = "on_train_dataloader_iter_creation_start"
143+
callbacks = self._callbacks.get(fn_name, [])
144+
for cb in callbacks:
145+
cb.on_train_dataloader_iter_creation_start(state, unit)
146+
147+
def on_train_dataloader_iter_creation_end(
148+
self, state: State, unit: TTrainUnit
149+
) -> None:
150+
fn_name = "on_train_dataloader_iter_creation_end"
151+
callbacks = self._callbacks.get(fn_name, [])
152+
for cb in callbacks:
153+
cb.on_train_dataloader_iter_creation_end(state, unit)
154+
155+
def on_train_get_next_batch_start(self, state: State, unit: TTrainUnit) -> None:
156+
fn_name = "on_train_get_next_batch_start"
157+
callbacks = self._callbacks.get(fn_name, [])
158+
for cb in callbacks:
159+
cb.on_train_get_next_batch_start(state, unit)
160+
130161
def on_train_get_next_batch_end(self, state: State, unit: TTrainUnit) -> None:
131162
fn_name = "on_train_get_next_batch_end"
132163
callbacks = self._callbacks.get(fn_name, [])
@@ -169,6 +200,28 @@ def on_eval_epoch_start(self, state: State, unit: TEvalUnit) -> None:
169200
for cb in callbacks:
170201
cb.on_eval_epoch_start(state, unit)
171202

203+
def on_eval_dataloader_iter_creation_start(
204+
self, state: State, unit: TEvalUnit
205+
) -> None:
206+
fn_name = "on_eval_dataloader_iter_creation_start"
207+
callbacks = self._callbacks.get(fn_name, [])
208+
for cb in callbacks:
209+
cb.on_eval_dataloader_iter_creation_start(state, unit)
210+
211+
def on_eval_dataloader_iter_creation_end(
212+
self, state: State, unit: TEvalUnit
213+
) -> None:
214+
fn_name = "on_eval_dataloader_iter_creation_end"
215+
callbacks = self._callbacks.get(fn_name, [])
216+
for cb in callbacks:
217+
cb.on_eval_dataloader_iter_creation_end(state, unit)
218+
219+
def on_eval_get_next_batch_start(self, state: State, unit: TEvalUnit) -> None:
220+
fn_name = "on_eval_get_next_batch_start"
221+
callbacks = self._callbacks.get(fn_name, [])
222+
for cb in callbacks:
223+
cb.on_eval_get_next_batch_start(state, unit)
224+
172225
def on_eval_get_next_batch_end(self, state: State, unit: TEvalUnit) -> None:
173226
fn_name = "on_eval_get_next_batch_end"
174227
callbacks = self._callbacks.get(fn_name, [])
@@ -211,6 +264,28 @@ def on_predict_epoch_start(self, state: State, unit: TPredictUnit) -> None:
211264
for cb in callbacks:
212265
cb.on_predict_epoch_start(state, unit)
213266

267+
def on_predict_dataloader_iter_creation_start(
268+
self, state: State, unit: TPredictUnit
269+
) -> None:
270+
fn_name = "on_predict_dataloader_iter_creation_start"
271+
callbacks = self._callbacks.get(fn_name, [])
272+
for cb in callbacks:
273+
cb.on_predict_dataloader_iter_creation_start(state, unit)
274+
275+
def on_predict_dataloader_iter_creation_end(
276+
self, state: State, unit: TPredictUnit
277+
) -> None:
278+
fn_name = "on_predict_dataloader_iter_creation_end"
279+
callbacks = self._callbacks.get(fn_name, [])
280+
for cb in callbacks:
281+
cb.on_predict_dataloader_iter_creation_end(state, unit)
282+
283+
def on_predict_get_next_batch_start(self, state: State, unit: TPredictUnit) -> None:
284+
fn_name = "on_predict_get_next_batch_start"
285+
callbacks = self._callbacks.get(fn_name, [])
286+
for cb in callbacks:
287+
cb.on_predict_get_next_batch_start(state, unit)
288+
214289
def on_predict_get_next_batch_end(self, state: State, unit: TPredictUnit) -> None:
215290
fn_name = "on_predict_get_next_batch_end"
216291
callbacks = self._callbacks.get(fn_name, [])

torchtnt/framework/callback.py

Lines changed: 48 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,22 @@ def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None:
7777
"""Hook called before a new train epoch starts."""
7878
pass
7979

80+
def on_train_dataloader_iter_creation_start(
81+
self, state: State, unit: TTrainUnit
82+
) -> None:
83+
"""Hook called before the dataloader iterator is created."""
84+
pass
85+
86+
def on_train_dataloader_iter_creation_end(
87+
self, state: State, unit: TTrainUnit
88+
) -> None:
89+
"""Hook called after the dataloader iterator is created."""
90+
pass
91+
92+
def on_train_get_next_batch_start(self, state: State, unit: TTrainUnit) -> None:
93+
"""Hook called before getting the data batch for the next train step."""
94+
pass
95+
8096
def on_train_get_next_batch_end(self, state: State, unit: TTrainUnit) -> None:
8197
"""Hook called after getting the data batch for the next train step."""
8298
pass
@@ -105,6 +121,22 @@ def on_eval_epoch_start(self, state: State, unit: TEvalUnit) -> None:
105121
"""Hook called before a new eval epoch starts."""
106122
pass
107123

124+
def on_eval_dataloader_iter_creation_start(
125+
self, state: State, unit: TEvalUnit
126+
) -> None:
127+
"""Hook called before the dataloader iterator is created."""
128+
pass
129+
130+
def on_eval_dataloader_iter_creation_end(
131+
self, state: State, unit: TEvalUnit
132+
) -> None:
133+
"""Hook called after the dataloader iterator is created."""
134+
pass
135+
136+
def on_eval_get_next_batch_start(self, state: State, unit: TEvalUnit) -> None:
137+
"""Hook called before getting the data batch for the next eval step."""
138+
pass
139+
108140
def on_eval_get_next_batch_end(self, state: State, unit: TEvalUnit) -> None:
109141
"""Hook called after getting the data batch for the next eval step."""
110142
pass
@@ -133,6 +165,22 @@ def on_predict_epoch_start(self, state: State, unit: TPredictUnit) -> None:
133165
"""Hook called before a new predict epoch starts."""
134166
pass
135167

168+
def on_predict_dataloader_iter_creation_start(
169+
self, state: State, unit: TPredictUnit
170+
) -> None:
171+
"""Hook called before the dataloader iterator is created."""
172+
pass
173+
174+
def on_predict_dataloader_iter_creation_end(
175+
self, state: State, unit: TPredictUnit
176+
) -> None:
177+
"""Hook called after the dataloader iterator is created."""
178+
pass
179+
180+
def on_predict_get_next_batch_start(self, state: State, unit: TPredictUnit) -> None:
181+
"""Hook called before getting the data batch for the next predict step."""
182+
pass
183+
136184
def on_predict_get_next_batch_end(self, state: State, unit: TPredictUnit) -> None:
137185
"""Hook called after getting the data batch for the next predict step."""
138186
pass

torchtnt/framework/callbacks/lambda_callback.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,15 @@ def __init__(
8282
] = None,
8383
on_train_start: Optional[Callable[[State, TTrainUnit], None]] = None,
8484
on_train_epoch_start: Optional[Callable[[State, TTrainUnit], None]] = None,
85+
on_train_dataloader_iter_creation_start: Optional[
86+
Callable[[State, TTrainUnit], None]
87+
] = None,
88+
on_train_dataloader_iter_creation_end: Optional[
89+
Callable[[State, TTrainUnit], None]
90+
] = None,
91+
on_train_get_next_batch_start: Optional[
92+
Callable[[State, TTrainUnit], None]
93+
] = None,
8594
on_train_get_next_batch_end: Optional[
8695
Callable[[State, TTrainUnit], None]
8796
] = None,
@@ -91,13 +100,31 @@ def __init__(
91100
on_train_end: Optional[Callable[[State, TTrainUnit], None]] = None,
92101
on_eval_start: Optional[Callable[[State, TEvalUnit], None]] = None,
93102
on_eval_epoch_start: Optional[Callable[[State, TEvalUnit], None]] = None,
103+
on_eval_dataloader_iter_creation_start: Optional[
104+
Callable[[State, TTrainUnit], None]
105+
] = None,
106+
on_eval_dataloader_iter_creation_end: Optional[
107+
Callable[[State, TTrainUnit], None]
108+
] = None,
109+
on_eval_get_next_batch_start: Optional[
110+
Callable[[State, TTrainUnit], None]
111+
] = None,
94112
on_eval_get_next_batch_end: Optional[Callable[[State, TEvalUnit], None]] = None,
95113
on_eval_step_start: Optional[Callable[[State, TEvalUnit], None]] = None,
96114
on_eval_step_end: Optional[Callable[[State, TEvalUnit], None]] = None,
97115
on_eval_epoch_end: Optional[Callable[[State, TEvalUnit], None]] = None,
98116
on_eval_end: Optional[Callable[[State, TEvalUnit], None]] = None,
99117
on_predict_start: Optional[Callable[[State, TPredictUnit], None]] = None,
100118
on_predict_epoch_start: Optional[Callable[[State, TPredictUnit], None]] = None,
119+
on_predict_dataloader_iter_creation_start: Optional[
120+
Callable[[State, TTrainUnit], None]
121+
] = None,
122+
on_predict_dataloader_iter_creation_end: Optional[
123+
Callable[[State, TTrainUnit], None]
124+
] = None,
125+
on_predict_get_next_batch_start: Optional[
126+
Callable[[State, TTrainUnit], None]
127+
] = None,
101128
on_predict_get_next_batch_end: Optional[
102129
Callable[[State, TPredictUnit], None]
103130
] = None,

torchtnt/framework/evaluate.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -132,9 +132,10 @@ def _evaluate_impl(
132132
eval_unit.on_eval_epoch_start(state)
133133
callback_handler.on_eval_epoch_start(state, eval_unit)
134134

135+
callback_handler.on_eval_dataloader_iter_creation_start(state, eval_unit)
135136
with get_timing_context(state, "evaluate.iter(dataloader)"):
136137
data_iter = iter(eval_state.dataloader)
137-
step_input = data_iter
138+
callback_handler.on_eval_dataloader_iter_creation_end(state, eval_unit)
138139

139140
prev_steps_in_epoch = eval_unit.eval_progress.num_steps_completed_in_epoch
140141

@@ -151,6 +152,7 @@ def _evaluate_impl(
151152
with get_timing_context(
152153
state, "evaluate.next(data_iter)"
153154
), eval_state.iteration_timer.time("data_wait_time"):
155+
callback_handler.on_eval_get_next_batch_start(state, eval_unit)
154156
step_input = eval_unit.get_next_eval_batch(state, data_iter)
155157
callback_handler.on_eval_get_next_batch_end(state, eval_unit)
156158

0 commit comments

Comments
 (0)