Skip to content

Commit 97b68cc

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Make AppStateMixin metric aware (#899)
Summary: Pull Request resolved: #899 Reviewed By: JKSenthil Differential Revision: D62555231 fbshipit-source-id: 2ec65e879f71d17c56871ef8a6bb1b706f7e1f62
1 parent e904d10 commit 97b68cc

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

torchtnt/framework/unit.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from torchtnt.utils.lr_scheduler import TLRScheduler
2323
from torchtnt.utils.prepare_module import _is_fsdp_module, FSDPOptimizerWrapper
2424
from torchtnt.utils.progress import Progress
25-
from torchtnt.utils.stateful import Stateful
25+
from torchtnt.utils.stateful import MetricStateful, Stateful
2626

2727

2828
_logger: logging.Logger = logging.getLogger(__name__)
@@ -51,6 +51,7 @@ def __init__(self) -> None:
5151
self._optimizers: Dict[str, torch.optim.Optimizer] = {}
5252
self._lr_schedulers: Dict[str, TLRScheduler] = {}
5353
self._progress: Dict[str, Progress] = {}
54+
self._metrics: Dict[str, MetricStateful] = {}
5455
# catch-all for miscellaneous statefuls
5556
self._misc_statefuls: Dict[str, Any] = {}
5657
# TODO: include other known statefuls
@@ -67,6 +68,7 @@ def app_state(self) -> Dict[str, Any]:
6768
**self.tracked_lr_schedulers(),
6869
**self.tracked_progress(),
6970
**self.tracked_misc_statefuls(),
71+
**self.tracked_metrics(),
7072
}
7173
return app_state
7274

@@ -84,6 +86,9 @@ def tracked_lr_schedulers(
8486
def tracked_progress(self) -> Dict[str, Progress]:
8587
return self._progress
8688

89+
def tracked_metrics(self) -> Dict[str, MetricStateful]:
90+
return self._metrics
91+
8792
def tracked_misc_statefuls(self) -> Dict[str, Any]:
8893
return self._misc_statefuls
8994

@@ -104,6 +109,10 @@ def __getattr__(self, name: str) -> object:
104109
_progress = self.__dict__["_progress"]
105110
if name in _progress:
106111
return _progress[name]
112+
if "_metrics" in self.__dict__:
113+
_metrics = self.__dict__["_metrics"]
114+
if name in _metrics:
115+
return _metrics[name]
107116
if "_misc_statefuls" in self.__dict__:
108117
_misc_statefuls = self.__dict__["_misc_statefuls"]
109118
if name in _misc_statefuls:
@@ -128,12 +137,16 @@ def _update_attr(
128137
self._optimizers,
129138
self._lr_schedulers,
130139
self._progress,
140+
self._metrics,
131141
self._misc_statefuls,
132142
)
133143
tracked_objects[name] = value
134144

135145
def __setattr__(self, name: str, value: object) -> None:
136-
if isinstance(value, torch.nn.Module):
146+
# Check first for metrics since some libraries subclass nn.Module as well
147+
if isinstance(value, MetricStateful):
148+
self._update_attr(name, value, self.__dict__.get("_metrics"))
149+
elif isinstance(value, torch.nn.Module):
137150
self._update_attr(name, value, self.__dict__.get("_modules"))
138151
elif isinstance(value, torch.optim.Optimizer):
139152
self._update_attr(name, value, self.__dict__.get("_optimizers"))
@@ -163,6 +176,7 @@ def __setattr__(self, name: str, value: object) -> None:
163176
self._modules,
164177
self._optimizers,
165178
self._lr_schedulers,
179+
self._metrics,
166180
self._misc_statefuls,
167181
)
168182
super().__setattr__(name, value)
@@ -176,6 +190,8 @@ def __delattr__(self, name: str) -> None:
176190
del self._lr_schedulers[name]
177191
elif name in self._progress:
178192
del self._progress[name]
193+
elif name in self._metrics:
194+
del self._metrics[name]
179195
elif name in self._misc_statefuls:
180196
del self._misc_statefuls[name]
181197
else:

0 commit comments

Comments
 (0)