Skip to content

Commit e904d10

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Introduce stateful metric protocol (#894)
Summary: Pull Request resolved: #894 Reviewed By: JKSenthil Differential Revision: D62395083 fbshipit-source-id: 3be009544368aec09a9bfcd9affea0a86a4dfaa8
1 parent 8ee0aa9 commit e904d10

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

torchtnt/utils/stateful.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...
3636
class MultiStateful:
3737
"""
3838
Wrapper for multiple stateful objects. Necessary because we might have multiple nn.Modules or multiple optimizers,
39-
but save/load_checkpoint APIs may only accepts one stateful object.
39+
but save/load_checkpoint APIs may only accept one stateful object.
4040
4141
Stores state_dict as a dict of state_dicts.
4242
"""
@@ -55,3 +55,20 @@ def state_dict(self) -> Dict[str, Any]:
5555
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
5656
for k in state_dict:
5757
self.stateful_objs[k].load_state_dict(state_dict[k])
58+
59+
60+
@runtime_checkable
61+
class MetricStateful(Protocol):
62+
"""
63+
Defines the interfaces for metric objects that can be saved and loaded from checkpoints.
64+
This conforms to the API exposed by major metric libraries like torcheval.
65+
"""
66+
67+
def update(self, *_: Any, **__: Any) -> None: ...
68+
69+
# pyre-ignore[3]: Metric computation may return any type depending on the implementation
70+
def compute(self) -> Any: ...
71+
72+
def state_dict(self) -> Dict[str, Any]: ...
73+
74+
def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...

0 commit comments

Comments
 (0)