@@ -22,7 +22,7 @@
 from torchtnt.utils.lr_scheduler import TLRScheduler
 from torchtnt.utils.prepare_module import _is_fsdp_module, FSDPOptimizerWrapper
 from torchtnt.utils.progress import Progress
-from torchtnt.utils.stateful import Stateful
+from torchtnt.utils.stateful import MetricStateful, Stateful


 _logger: logging.Logger = logging.getLogger(__name__)
@@ -51,6 +51,7 @@ def __init__(self) -> None:
         self._optimizers: Dict[str, torch.optim.Optimizer] = {}
         self._lr_schedulers: Dict[str, TLRScheduler] = {}
         self._progress: Dict[str, Progress] = {}
+        self._metrics: Dict[str, MetricStateful] = {}
         # catch-all for miscellaneous statefuls
         self._misc_statefuls: Dict[str, Any] = {}
         # TODO: include other known statefuls
@@ -67,6 +68,7 @@ def app_state(self) -> Dict[str, Any]:
             **self.tracked_lr_schedulers(),
             **self.tracked_progress(),
             **self.tracked_misc_statefuls(),
+            **self.tracked_metrics(),
         }
         return app_state

@@ -84,6 +86,9 @@ def tracked_lr_schedulers(
     def tracked_progress(self) -> Dict[str, Progress]:
         return self._progress

+    def tracked_metrics(self) -> Dict[str, MetricStateful]:
+        return self._metrics
+
     def tracked_misc_statefuls(self) -> Dict[str, Any]:
         return self._misc_statefuls

@@ -104,6 +109,10 @@ def __getattr__(self, name: str) -> object:
             _progress = self.__dict__["_progress"]
             if name in _progress:
                 return _progress[name]
+        if "_metrics" in self.__dict__:
+            _metrics = self.__dict__["_metrics"]
+            if name in _metrics:
+                return _metrics[name]
         if "_misc_statefuls" in self.__dict__:
             _misc_statefuls = self.__dict__["_misc_statefuls"]
             if name in _misc_statefuls:
@@ -128,12 +137,16 @@ def _update_attr(
             self._optimizers,
             self._lr_schedulers,
             self._progress,
+            self._metrics,
             self._misc_statefuls,
         )
         tracked_objects[name] = value

     def __setattr__(self, name: str, value: object) -> None:
-        if isinstance(value, torch.nn.Module):
+        # Check first for metrics since some libraries subclass nn.Module as well
+        if isinstance(value, MetricStateful):
+            self._update_attr(name, value, self.__dict__.get("_metrics"))
+        elif isinstance(value, torch.nn.Module):
             self._update_attr(name, value, self.__dict__.get("_modules"))
         elif isinstance(value, torch.optim.Optimizer):
             self._update_attr(name, value, self.__dict__.get("_optimizers"))
@@ -163,6 +176,7 @@ def __setattr__(self, name: str, value: object) -> None:
             self._modules,
             self._optimizers,
             self._lr_schedulers,
+            self._metrics,
             self._misc_statefuls,
         )
         super().__setattr__(name, value)
@@ -176,6 +190,8 @@ def __delattr__(self, name: str) -> None:
             del self._lr_schedulers[name]
         elif name in self._progress:
             del self._progress[name]
+        elif name in self._metrics:
+            del self._metrics[name]
         elif name in self._misc_statefuls:
             del self._misc_statefuls[name]
         else:
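For illustration, here is a minimal sketch (not part of the PR) of what this change enables. It assumes MetricStateful is a runtime-checkable protocol satisfied by any object with update, compute, state_dict, and load_state_dict methods (the isinstance check in __setattr__ relies on this), and that AppStateMixin at torchtnt.framework.unit is the class being modified. MeanMetric and MyUnit below are hypothetical. Note the design point in the diff's comment: torchmetrics-style metrics also subclass nn.Module, so the MetricStateful branch must run before the nn.Module branch or the metric would land in _modules instead.

from typing import Any, Dict

from torchtnt.framework.unit import AppStateMixin  # assumed import path


class MeanMetric:
    """Toy metric exposing the assumed MetricStateful surface."""

    def __init__(self) -> None:
        self._total = 0.0
        self._count = 0

    def update(self, value: float) -> None:
        self._total += value
        self._count += 1

    def compute(self) -> float:
        return self._total / max(self._count, 1)

    def state_dict(self) -> Dict[str, Any]:
        return {"total": self._total, "count": self._count}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._total = state_dict["total"]
        self._count = state_dict["count"]


class MyUnit(AppStateMixin):
    def __init__(self) -> None:
        super().__init__()
        # With this PR, __setattr__ routes the metric into self._metrics
        # (rather than _misc_statefuls), so it is tracked and checkpointed
        # under its attribute name.
        self.loss_mean = MeanMetric()


unit = MyUnit()
assert "loss_mean" in unit.tracked_metrics()
assert "loss_mean" in unit.app_state()
unit.loss_mean.update(0.5)  # __getattr__ resolves loss_mean from _metrics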