Skip to content

Commit 06199b0

Browse files
committed
Add BC for timm- encoders
1 parent 4f3b37e commit 06199b0

File tree

1 file changed

+26
-0
lines changed
  • segmentation_models_pytorch/base

1 file changed

+26
-0
lines changed

segmentation_models_pytorch/base/model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,29 @@ def predict(self, x):
9090
x = self.forward(x)
9191

9292
return x
93+
94+
def load_state_dict(self, state_dict, **kwargs):
95+
# for compatibility of weights for
96+
# timm- ported encoders with TimmUniversalEncoder
97+
from segmentation_models_pytorch.encoders import TimmUniversalEncoder
98+
99+
if not isinstance(self.encoder, TimmUniversalEncoder):
100+
return super().load_state_dict(state_dict, **kwargs)
101+
102+
patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]
103+
104+
is_deprecated_encoder = any(
105+
self.encoder.name.startswith(pattern) for pattern in patterns
106+
)
107+
108+
if is_deprecated_encoder:
109+
keys = list(state_dict.keys())
110+
for key in keys:
111+
new_key = key
112+
if key.startswith("encoder.") and not key.startswith("encoder.model."):
113+
new_key = "encoder.model." + key.removeprefix("encoder.")
114+
if "gernet" in self.encoder.name:
115+
new_key = new_key.replace(".stages.", ".stages_")
116+
state_dict[new_key] = state_dict.pop(key)
117+
118+
return super().load_state_dict(state_dict, **kwargs)

0 commit comments

Comments
 (0)