File tree Expand file tree Collapse file tree 2 files changed +16
-18
lines changed
segmentation_models_pytorch/decoders/deeplabv3 Expand file tree Collapse file tree 2 files changed +16
-18
lines changed Original file line number Diff line number Diff line change 4040__all__ = ["DeepLabV3Decoder" , "DeepLabV3PlusDecoder" ]
4141
4242
43- class DeepLabV3Decoder (nn .Sequential ):
43+ class DeepLabV3Decoder (nn .Module ):
4444 def __init__ (
4545 self ,
4646 in_channels : int ,
@@ -69,23 +69,6 @@ def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
6969 x = self .relu (x )
7070 return x
7171
72- def load_state_dict (self , state_dict , * args , ** kwargs ):
73- # For backward compatibility, previously this module was Sequential
74- # and was not scriptable.
75- keys = list (state_dict .keys ())
76- for key in keys :
77- new_key = key
78- if key .startswith ("0." ):
79- new_key = "aspp." + key [2 :]
80- elif key .startswith ("1." ):
81- new_key = "conv." + key [2 :]
82- elif key .startswith ("2." ):
83- new_key = "bn." + key [2 :]
84- elif key .startswith ("3." ):
85- new_key = "relu." + key [2 :]
86- state_dict [new_key ] = state_dict .pop (key )
87- super ().load_state_dict (state_dict , * args , ** kwargs )
88-
8972
9073class DeepLabV3PlusDecoder (nn .Module ):
9174 def __init__ (
Original file line number Diff line number Diff line change @@ -121,6 +121,21 @@ def __init__(
121121 else :
122122 self .classification_head = None
123123
124+ def load_state_dict (self , state_dict , * args , ** kwargs ):
125+ # For backward compatibility, previously Decoder module was Sequential
126+ # and was not scriptable.
127+ keys = list (state_dict .keys ())
128+ for key in keys :
129+ new_key = key
130+ if key .startswith ("decoder.0." ):
131+ new_key = key .replace ("decoder.0." , "decoder.aspp." )
132+ elif key .startswith ("decoder.1." ):
133+ new_key = key .replace ("decoder.1." , "decoder.conv." )
134+ elif key .startswith ("decoder.2." ):
135+ new_key = key .replace ("decoder.2." , "decoder.bn." )
136+ state_dict [new_key ] = state_dict .pop (key )
137+ return super ().load_state_dict (state_dict , * args , ** kwargs )
138+
124139
125140class DeepLabV3Plus (SegmentationModel ):
126141 """DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable
You can’t perform that action at this time.
0 commit comments