Skip to content

Commit 7752969

Browse files
committed
Fix weight loading for deprecate encoders
1 parent da0cd19 commit 7752969

File tree

2 files changed

+22
-16
lines changed

2 files changed

+22
-16
lines changed

segmentation_models_pytorch/base/model.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from . import initialization as init
55
from .hub_mixin import SMPHubMixin
6-
from ..encoders.timm_universal import TimmUniversalEncoder
76

87
T = TypeVar("T", bound="SegmentationModel")
98

@@ -82,18 +81,3 @@ def predict(self, x):
8281
x = self.forward(x)
8382

8483
return x
85-
86-
def load_state_dict(self, state_dict, **kwargs):
87-
# for compatibility of weights for
88-
# timm- ported encoders with TimmUniversalEncoder
89-
if isinstance(self.encoder, TimmUniversalEncoder):
90-
keys = list(state_dict.keys())
91-
for key in keys:
92-
new_key = key
93-
if key.startswith("encoder.") and not key.startswith("encoder.model."):
94-
new_key = key.replace("encoder.", "encoder.model.")
95-
if "gernet" in self.encoder.name:
96-
new_key = new_key.replace(".stages.", ".stages_")
97-
state_dict[new_key] = state_dict.pop(key)
98-
99-
return super().load_state_dict(state_dict, **kwargs)

segmentation_models_pytorch/encoders/timm_universal.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,28 @@ def output_stride(self) -> int:
196196
"""
197197
return min(self._output_stride, 2**self._depth)
198198

199+
def load_state_dict(self, state_dict, **kwargs):
200+
# for compatibility of weights for
201+
# timm- ported encoders with TimmUniversalEncoder
202+
203+
patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]
204+
205+
is_deprecated_encoder = any(
206+
self.name.startswith(pattern) for pattern in patterns
207+
)
208+
209+
if is_deprecated_encoder:
210+
keys = list(state_dict.keys())
211+
for key in keys:
212+
new_key = key
213+
if not key.startswith("model."):
214+
new_key = "model." + key
215+
if "gernet" in self.name:
216+
new_key = new_key.replace(".stages.", ".stages_")
217+
state_dict[new_key] = state_dict.pop(key)
218+
219+
return super().load_state_dict(state_dict, **kwargs)
220+
199221

200222
def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
201223
"""

0 commit comments

Comments
 (0)