Skip to content

Commit eef8bb2

Browse files
authored
chore: fixing type hint checkpointing class (#590)
This PR fixes the type hint of some classes in the checkpointing code.
1 parent 3d82ca6 commit eef8bb2

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchtitan/checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class ModelWrapper(Stateful):
8484
def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None:
8585
self.model = [model] if isinstance(model, nn.Module) else model
8686

87-
def state_dict(self) -> None:
87+
def state_dict(self) -> Dict[str, Any]:
8888
return {
8989
k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
9090
}
@@ -107,7 +107,7 @@ def __init__(
107107
self.model = [model] if isinstance(model, nn.Module) else model
108108
self.optim = [optim] if isinstance(optim, torch.optim.Optimizer) else optim
109109

110-
def state_dict(self) -> None:
110+
def state_dict(self) -> Dict[str, Any]:
111111
func = functools.partial(
112112
get_optimizer_state_dict,
113113
options=StateDictOptions(flatten_optimizer_state_dict=True),

0 commit comments

Comments
 (0)