Skip to content

Commit e43a764

Browse files
Update autoencoder.py
1 parent 8e07303 commit e43a764

File tree

1 file changed

+3
-20
lines changed

1 file changed

+3
-20
lines changed

sat/vae_modules/autoencoder.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,24 +52,6 @@ def __init__(
5252
if version.parse(torch.__version__) >= version.parse("2.0.0"):
5353
self.automatic_optimization = False
5454

55-
def apply_ckpt(self, ckpt: Union[None, str, dict]):
56-
if ckpt is None:
57-
return
58-
self.init_from_ckpt(ckpt)
59-
60-
def init_from_ckpt(self, path, ignore_keys=list()):
61-
sd = torch.load(path, map_location="cpu")["state_dict"]
62-
keys = list(sd.keys())
63-
for k in keys:
64-
for ik in ignore_keys:
65-
if k.startswith(ik):
66-
print("Deleting key {} from state_dict.".format(k))
67-
del sd[k]
68-
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
69-
print("Missing keys: ", missing_keys)
70-
print("Unexpected keys: ", unexpected_keys)
71-
print(f"Restored from {path}")
72-
7355
def apply_ckpt(self, ckpt: Union[None, str, dict]):
7456
if ckpt is None:
7557
return
@@ -81,7 +63,6 @@ def apply_ckpt(self, ckpt: Union[None, str, dict]):
8163
engine = instantiate_from_config(ckpt)
8264
engine(self)
8365

84-
8566
@abstractmethod
8667
def get_input(self, batch) -> Any:
8768
raise NotImplementedError()
@@ -116,7 +97,9 @@ def decode(self, *args, **kwargs) -> torch.Tensor:
11697

11798
def instantiate_optimizer_from_config(self, params, lr, cfg):
11899
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
119-
return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
100+
return get_obj_from_str(cfg["target"])(
101+
params, lr=lr, **cfg.get("params", dict())
102+
)
120103

121104
def configure_optimizers(self) -> Any:
122105
raise NotImplementedError()

0 commit comments

Comments
 (0)