File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -84,7 +84,7 @@ class ModelWrapper(Stateful):
84
84
def __init__ (self , model : Union [nn .Module , List [nn .Module ]]) -> None :
85
85
self .model = [model ] if isinstance (model , nn .Module ) else model
86
86
87
- def state_dict (self ) -> None :
87
+ def state_dict (self ) -> Dict [ str , Any ] :
88
88
return {
89
89
k : v for sd in map (get_model_state_dict , self .model ) for k , v in sd .items ()
90
90
}
@@ -107,7 +107,7 @@ def __init__(
107
107
self .model = [model ] if isinstance (model , nn .Module ) else model
108
108
self .optim = [optim ] if isinstance (optim , torch .optim .Optimizer ) else optim
109
109
110
- def state_dict (self ) -> None :
110
+ def state_dict (self ) -> Dict [ str , Any ] :
111
111
func = functools .partial (
112
112
get_optimizer_state_dict ,
113
113
options = StateDictOptions (flatten_optimizer_state_dict = True ),
You can’t perform that action at this time.
0 commit comments