7
7
# pyre-strict
8
8
9
9
10
- from typing import Optional , Union
10
+ from typing import cast , Optional , Union
11
11
12
12
from pyre_extensions import none_throws
13
13
from torch .utils .tensorboard import SummaryWriter
14
14
15
15
from torchtnt .framework .callback import Callback
16
16
from torchtnt .framework .state import State
17
17
from torchtnt .framework .unit import TEvalUnit , TPredictUnit , TTrainUnit
18
- from torchtnt .utils .distributed import get_global_rank
19
- from torchtnt .utils .loggers .tensorboard import TensorBoardLogger
18
+ from torchtnt .utils .distributed import rank_zero_fn
19
+ from torchtnt .utils .loggers .logger import MetricLogger
20
20
from torchtnt .utils .timer import TimerProtocol
21
21
22
22
@@ -35,23 +35,17 @@ class IterationTimeLogger(Callback):
35
35
36
36
def __init__ (
37
37
self ,
38
- logger : Union [TensorBoardLogger , SummaryWriter ],
38
+ logger : Union [MetricLogger , SummaryWriter ],
39
39
moving_avg_window : int = 1 ,
40
40
log_every_n_steps : int = 1 ,
41
41
) -> None :
42
- if isinstance (logger , TensorBoardLogger ):
43
- logger = logger .writer
44
-
45
- if get_global_rank () == 0 : # only write from the main rank
46
- self ._writer = none_throws (
47
- logger , "TensorBoardLogger.writer should not be None"
48
- )
42
+ self ._logger = logger
49
43
self .moving_avg_window = moving_avg_window
50
44
self .log_every_n_steps = log_every_n_steps
51
45
46
+ @rank_zero_fn
52
47
def _log_step_metrics (
53
48
self ,
54
- writer : SummaryWriter ,
55
49
metric_label : str ,
56
50
iteration_timer : TimerProtocol ,
57
51
step_logging_for : int ,
@@ -75,38 +69,39 @@ def _log_step_metrics(
75
69
return
76
70
77
71
last_n_values = time_list [- self .moving_avg_window :]
78
- writer .add_scalar (
79
- human_metric_names [metric_label ],
80
- sum (last_n_values ) / len (last_n_values ),
81
- step_logging_for ,
82
- )
72
+ if isinstance (self ._logger , SummaryWriter ):
73
+ self ._logger .add_scalar (
74
+ human_metric_names [metric_label ],
75
+ sum (last_n_values ) / len (last_n_values ),
76
+ step_logging_for ,
77
+ )
78
+ else :
79
+ cast (MetricLogger , self ._logger ).log (
80
+ human_metric_names [metric_label ],
81
+ sum (last_n_values ) / len (last_n_values ),
82
+ step_logging_for ,
83
+ )
83
84
84
85
def on_train_step_end (self , state : State , unit : TTrainUnit ) -> None :
85
86
timer = none_throws (state .train_state ).iteration_timer
86
- if writer := self ._writer :
87
- self ._log_step_metrics (
88
- writer ,
89
- "train_iteration_time" ,
90
- timer ,
91
- unit .train_progress .num_steps_completed ,
92
- )
87
+ self ._log_step_metrics (
88
+ "train_iteration_time" ,
89
+ timer ,
90
+ unit .train_progress .num_steps_completed ,
91
+ )
93
92
94
93
def on_eval_step_end (self , state : State , unit : TEvalUnit ) -> None :
95
94
timer = none_throws (state .eval_state ).iteration_timer
96
- if writer := self ._writer :
97
- self ._log_step_metrics (
98
- writer ,
99
- "eval_iteration_time" ,
100
- timer ,
101
- unit .eval_progress .num_steps_completed ,
102
- )
95
+ self ._log_step_metrics (
96
+ "eval_iteration_time" ,
97
+ timer ,
98
+ unit .eval_progress .num_steps_completed ,
99
+ )
103
100
104
101
def on_predict_step_end (self , state : State , unit : TPredictUnit ) -> None :
105
102
timer = none_throws (state .predict_state ).iteration_timer
106
- if writer := self ._writer :
107
- self ._log_step_metrics (
108
- writer ,
109
- "predict_iteration_time" ,
110
- timer ,
111
- unit .predict_progress .num_steps_completed ,
112
- )
103
+ self ._log_step_metrics (
104
+ "predict_iteration_time" ,
105
+ timer ,
106
+ unit .predict_progress .num_steps_completed ,
107
+ )
0 commit comments