Skip to content

Commit 72df3db

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
log intervals in callback handler (#943)
Summary: Pull Request resolved: #943 Reviewed By: diego-urgell Differential Revision: D65622296 fbshipit-source-id: ca9136e42045d12b23534afd900419cc413e8383
1 parent b442b1e commit 72df3db

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torchtnt/framework/_callback_handler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from torchtnt.framework.callback import Callback
1515
from torchtnt.framework.state import State
1616
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
17+
from torchtnt.utils.event_handlers import log_interval
1718

1819
logger: logging.Logger = logging.getLogger(__name__)
1920

@@ -124,6 +125,7 @@ def on_exception(
124125
for cb in callbacks:
125126
cb.on_exception(state, unit, exc)
126127

128+
@log_interval("on_train_start", {"category": "callback_handler"})
127129
def on_train_start(self, state: State, unit: TTrainUnit) -> None:
128130
fn_name = "on_train_start"
129131
callbacks = self._callbacks.get(fn_name, [])
@@ -176,12 +178,14 @@ def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
176178
for cb in callbacks:
177179
cb.on_train_step_end(state, unit)
178180

181+
@log_interval("on_train_epoch_end", {"category": "callback_handler"})
179182
def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
180183
fn_name = "on_train_epoch_end"
181184
callbacks = self._callbacks.get(fn_name, [])
182185
for cb in callbacks:
183186
cb.on_train_epoch_end(state, unit)
184187

188+
@log_interval("on_train_end", {"category": "callback_handler"})
185189
def on_train_end(self, state: State, unit: TTrainUnit) -> None:
186190
fn_name = "on_train_end"
187191
callbacks = self._callbacks.get(fn_name, [])

0 commit comments

Comments
 (0)