Skip to content

Commit 6e5b158

Browse files
Rabia Shakoorfacebook-github-bot
authored andcommitted
added a stateful dictionary (#983)
Summary: Pull Request resolved: #983 Added a stateful dictionary object that implements the stateful interface for checkpoint saving and loading and also extends dict. This object can be used to store key/value pairs of arbitrary value objects (for example, computed metrics or any other object that we require to be saved to the unit state) Reviewed By: diego-urgell Differential Revision: D71051443 fbshipit-source-id: 7a94ae43641a0bc3d93690f015bfa75ca279362b
1 parent 4059cc4 commit 6e5b158

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

torchtnt/utils/stateful.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,14 @@ def compute(self) -> Any: ...
7575
def state_dict(self) -> Dict[str, Any]: ...
7676

7777
def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...
78+
79+
80+
class DictStateful(Stateful, Dict[str, Any]):
81+
"""A dictionary that implements the stateful interface that can be saved and loaded from checkpoints."""
82+
83+
def state_dict(self) -> Dict[str, Any]:
84+
return self
85+
86+
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
87+
self.clear()
88+
self.update(state_dict)

0 commit comments

Comments
 (0)