|
14 | 14 | from torchtnt.framework.callback import Callback
|
15 | 15 | from torchtnt.framework.state import State
|
16 | 16 | from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
|
| 17 | +from torchtnt.utils.event_handlers import log_interval |
17 | 18 |
|
18 | 19 | logger: logging.Logger = logging.getLogger(__name__)
|
19 | 20 |
|
@@ -124,6 +125,7 @@ def on_exception(
|
124 | 125 | for cb in callbacks:
|
125 | 126 | cb.on_exception(state, unit, exc)
|
126 | 127 |
|
| 128 | + @log_interval("on_train_start", {"category": "callback_handler"}) |
127 | 129 | def on_train_start(self, state: State, unit: TTrainUnit) -> None:
|
128 | 130 | fn_name = "on_train_start"
|
129 | 131 | callbacks = self._callbacks.get(fn_name, [])
|
@@ -176,12 +178,14 @@ def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
|
176 | 178 | for cb in callbacks:
|
177 | 179 | cb.on_train_step_end(state, unit)
|
178 | 180 |
|
| 181 | + @log_interval("on_train_epoch_end", {"category": "callback_handler"}) |
179 | 182 | def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
|
180 | 183 | fn_name = "on_train_epoch_end"
|
181 | 184 | callbacks = self._callbacks.get(fn_name, [])
|
182 | 185 | for cb in callbacks:
|
183 | 186 | cb.on_train_epoch_end(state, unit)
|
184 | 187 |
|
| 188 | + @log_interval("on_train_end", {"category": "callback_handler"}) |
185 | 189 | def on_train_end(self, state: State, unit: TTrainUnit) -> None:
|
186 | 190 | fn_name = "on_train_end"
|
187 | 191 | callbacks = self._callbacks.get(fn_name, [])
|
|
0 commit comments