@@ -52,24 +52,6 @@ def __init__(
5252 if version .parse (torch .__version__ ) >= version .parse ("2.0.0" ):
5353 self .automatic_optimization = False
5454
55- def apply_ckpt (self , ckpt : Union [None , str , dict ]):
56- if ckpt is None :
57- return
58- self .init_from_ckpt (ckpt )
59-
60- def init_from_ckpt (self , path , ignore_keys = list ()):
61- sd = torch .load (path , map_location = "cpu" )["state_dict" ]
62- keys = list (sd .keys ())
63- for k in keys :
64- for ik in ignore_keys :
65- if k .startswith (ik ):
66- print ("Deleting key {} from state_dict." .format (k ))
67- del sd [k ]
68- missing_keys , unexpected_keys = self .load_state_dict (sd , strict = False )
69- print ("Missing keys: " , missing_keys )
70- print ("Unexpected keys: " , unexpected_keys )
71- print (f"Restored from { path } " )
72-
7355 def apply_ckpt (self , ckpt : Union [None , str , dict ]):
7456 if ckpt is None :
7557 return
@@ -81,7 +63,6 @@ def apply_ckpt(self, ckpt: Union[None, str, dict]):
8163 engine = instantiate_from_config (ckpt )
8264 engine (self )
8365
84-
8566 @abstractmethod
8667 def get_input (self , batch ) -> Any :
8768 raise NotImplementedError ()
@@ -116,7 +97,9 @@ def decode(self, *args, **kwargs) -> torch.Tensor:
11697
11798 def instantiate_optimizer_from_config (self , params , lr , cfg ):
11899 logpy .info (f"loading >>> { cfg ['target' ]} <<< optimizer from config" )
119- return get_obj_from_str (cfg ["target" ])(params , lr = lr , ** cfg .get ("params" , dict ()))
100+ return get_obj_from_str (cfg ["target" ])(
101+ params , lr = lr , ** cfg .get ("params" , dict ())
102+ )
120103
121104 def configure_optimizers (self ) -> Any :
122105 raise NotImplementedError ()
0 commit comments