File tree Expand file tree Collapse file tree 2 files changed +16
-18
lines changed
segmentation_models_pytorch/decoders/deeplabv3 Expand file tree Collapse file tree 2 files changed +16
-18
lines changed Original file line number Diff line number Diff line change 40
40
__all__ = ["DeepLabV3Decoder" , "DeepLabV3PlusDecoder" ]
41
41
42
42
43
- class DeepLabV3Decoder (nn .Sequential ):
43
+ class DeepLabV3Decoder (nn .Module ):
44
44
def __init__ (
45
45
self ,
46
46
in_channels : int ,
@@ -69,23 +69,6 @@ def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
69
69
x = self .relu (x )
70
70
return x
71
71
72
- def load_state_dict (self , state_dict , * args , ** kwargs ):
73
- # For backward compatibility, previously this module was Sequential
74
- # and was not scriptable.
75
- keys = list (state_dict .keys ())
76
- for key in keys :
77
- new_key = key
78
- if key .startswith ("0." ):
79
- new_key = "aspp." + key [2 :]
80
- elif key .startswith ("1." ):
81
- new_key = "conv." + key [2 :]
82
- elif key .startswith ("2." ):
83
- new_key = "bn." + key [2 :]
84
- elif key .startswith ("3." ):
85
- new_key = "relu." + key [2 :]
86
- state_dict [new_key ] = state_dict .pop (key )
87
- super ().load_state_dict (state_dict , * args , ** kwargs )
88
-
89
72
90
73
class DeepLabV3PlusDecoder (nn .Module ):
91
74
def __init__ (
Original file line number Diff line number Diff line change @@ -121,6 +121,21 @@ def __init__(
121
121
else :
122
122
self .classification_head = None
123
123
124
+ def load_state_dict (self , state_dict , * args , ** kwargs ):
125
+ # For backward compatibility, previously Decoder module was Sequential
126
+ # and was not scriptable.
127
+ keys = list (state_dict .keys ())
128
+ for key in keys :
129
+ new_key = key
130
+ if key .startswith ("decoder.0." ):
131
+ new_key = key .replace ("decoder.0." , "decoder.aspp." )
132
+ elif key .startswith ("decoder.1." ):
133
+ new_key = key .replace ("decoder.1." , "decoder.conv." )
134
+ elif key .startswith ("decoder.2." ):
135
+ new_key = key .replace ("decoder.2." , "decoder.bn." )
136
+ state_dict [new_key ] = state_dict .pop (key )
137
+ return super ().load_state_dict (state_dict , * args , ** kwargs )
138
+
124
139
125
140
class DeepLabV3Plus (SegmentationModel ):
126
141
"""DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable
You can’t perform that action at this time.
0 commit comments