Skip to content

Commit 556b3aa

Browse files
committed
Fix DeepLabV3 BC
1 parent 31bee79 commit 556b3aa

File tree

2 files changed

+16
-18
lines changed

2 files changed

+16
-18
lines changed

segmentation_models_pytorch/decoders/deeplabv3/decoder.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
__all__ = ["DeepLabV3Decoder", "DeepLabV3PlusDecoder"]
4141

4242

43-
class DeepLabV3Decoder(nn.Sequential):
43+
class DeepLabV3Decoder(nn.Module):
4444
def __init__(
4545
self,
4646
in_channels: int,
@@ -69,23 +69,6 @@ def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
6969
x = self.relu(x)
7070
return x
7171

72-
def load_state_dict(self, state_dict, *args, **kwargs):
73-
# For backward compatibility, previously this module was Sequential
74-
# and was not scriptable.
75-
keys = list(state_dict.keys())
76-
for key in keys:
77-
new_key = key
78-
if key.startswith("0."):
79-
new_key = "aspp." + key[2:]
80-
elif key.startswith("1."):
81-
new_key = "conv." + key[2:]
82-
elif key.startswith("2."):
83-
new_key = "bn." + key[2:]
84-
elif key.startswith("3."):
85-
new_key = "relu." + key[2:]
86-
state_dict[new_key] = state_dict.pop(key)
87-
super().load_state_dict(state_dict, *args, **kwargs)
88-
8972

9073
class DeepLabV3PlusDecoder(nn.Module):
9174
def __init__(

segmentation_models_pytorch/decoders/deeplabv3/model.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,21 @@ def __init__(
121121
else:
122122
self.classification_head = None
123123

124+
def load_state_dict(self, state_dict, *args, **kwargs):
125+
# For backward compatibility, previously Decoder module was Sequential
126+
# and was not scriptable.
127+
keys = list(state_dict.keys())
128+
for key in keys:
129+
new_key = key
130+
if key.startswith("decoder.0."):
131+
new_key = key.replace("decoder.0.", "decoder.aspp.")
132+
elif key.startswith("decoder.1."):
133+
new_key = key.replace("decoder.1.", "decoder.conv.")
134+
elif key.startswith("decoder.2."):
135+
new_key = key.replace("decoder.2.", "decoder.bn.")
136+
state_dict[new_key] = state_dict.pop(key)
137+
return super().load_state_dict(state_dict, *args, **kwargs)
138+
124139

125140
class DeepLabV3Plus(SegmentationModel):
126141
"""DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable

0 commit comments

Comments
 (0)