Skip to content

Commit b8586c7

Browse files
committed
Rename for PAN
1 parent f814adc commit b8586c7

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

segmentation_models_pytorch/decoders/pan/decoder.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,12 @@ def __init__(
168168
self,
169169
in_channels: int,
170170
out_channels: int,
171-
decoder_interpolation_mode: str = "bilinear",
171+
interpolation_mode: str = "bilinear",
172172
):
173173
super(GAUBlock, self).__init__()
174174

175-
self.interpolation_mode = decoder_interpolation_mode
176-
self.align_corners = True if decoder_interpolation_mode == "bilinear" else None
175+
self.interpolation_mode = interpolation_mode
176+
self.align_corners = True if interpolation_mode == "bilinear" else None
177177

178178
self.conv1 = nn.Sequential(
179179
nn.AdaptiveAvgPool2d(1),
@@ -214,7 +214,7 @@ def __init__(
214214
encoder_channels: Sequence[int],
215215
encoder_depth: Literal[3, 4, 5],
216216
decoder_channels: int,
217-
decoder_interpolation_mode: str = "bilinear",
217+
interpolation_mode: str = "bilinear",
218218
):
219219
super().__init__()
220220

@@ -235,19 +235,19 @@ def __init__(
235235
self.gau3 = GAUBlock(
236236
in_channels=encoder_channels[2],
237237
out_channels=decoder_channels,
238-
decoder_interpolation_mode=decoder_interpolation_mode,
238+
interpolation_mode=interpolation_mode,
239239
)
240240
if encoder_depth >= 4:
241241
self.gau2 = GAUBlock(
242242
in_channels=encoder_channels[1],
243243
out_channels=decoder_channels,
244-
decoder_interpolation_mode=decoder_interpolation_mode,
244+
interpolation_mode=interpolation_mode,
245245
)
246246
if encoder_depth >= 3:
247247
self.gau1 = GAUBlock(
248248
in_channels=encoder_channels[0],
249249
out_channels=decoder_channels,
250-
decoder_interpolation_mode=decoder_interpolation_mode,
250+
interpolation_mode=interpolation_mode,
251251
)
252252

253253
def forward(self, features: List[torch.Tensor]) -> torch.Tensor:

segmentation_models_pytorch/decoders/pan/model.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class PAN(SegmentationModel):
3131
encoder_output_stride: 16 or 32, if 16 use dilation in encoder last layer.
3232
Doesn't work with ***ception***, **vgg***, **densenet*`** backbones.Default is 16.
3333
decoder_channels: A number of convolution layer filters in decoder blocks
34-
decoder_interpolation_mode: Interpolation mode used in decoder of the model. Available options are
34+
decoder_interpolation: Interpolation mode used in decoder of the model. Available options are
3535
**"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"bilinear"**.
3636
in_channels: A number of input channels for the model, default is 3 (RGB images)
3737
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
@@ -65,7 +65,7 @@ def __init__(
6565
encoder_weights: Optional[str] = "imagenet",
6666
encoder_output_stride: Literal[16, 32] = 16,
6767
decoder_channels: int = 32,
68-
decoder_interpolation_mode: str = "bilinear",
68+
decoder_interpolation: str = "bilinear",
6969
in_channels: int = 3,
7070
classes: int = 1,
7171
activation: Optional[Union[str, Callable]] = None,
@@ -82,6 +82,15 @@ def __init__(
8282
)
8383
)
8484

85+
upscale_mode = kwargs.pop("upscale_mode", None)
86+
if upscale_mode is not None:
87+
warnings.warn(
88+
"The usage of upscale_mode is deprecated. Please modify your code for decoder_interpolation",
89+
DeprecationWarning,
90+
stacklevel=2,
91+
)
92+
decoder_interpolation = upscale_mode
93+
8594
self.encoder = get_encoder(
8695
encoder_name,
8796
in_channels=in_channels,
@@ -91,20 +100,11 @@ def __init__(
91100
**kwargs,
92101
)
93102

94-
upscale_mode = kwargs.pop("upscale_mode", None)
95-
if upscale_mode is not None:
96-
warnings.warn(
97-
"The usage of upscale_mode is deprecated. Please modify your code for decoder_interpolation_mode",
98-
DeprecationWarning,
99-
stacklevel=2,
100-
)
101-
decoder_interpolation_mode = upscale_mode
102-
103103
self.decoder = PANDecoder(
104104
encoder_channels=self.encoder.out_channels,
105105
encoder_depth=encoder_depth,
106106
decoder_channels=decoder_channels,
107-
decoder_interpolation_mode=decoder_interpolation_mode,
107+
interpolation_mode=decoder_interpolation,
108108
)
109109

110110
self.segmentation_head = SegmentationHead(

0 commit comments

Comments
 (0)