Skip to content

Commit da0cd19

Browse files
committed
Update get_stages
1 parent e12ee8d commit da0cd19

File tree

12 files changed

+85
-52
lines changed

12 files changed

+85
-52
lines changed

segmentation_models_pytorch/encoders/_base.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import torch
2+
from typing import Sequence
3+
14
from . import _utils as utils
25

36

@@ -31,28 +34,21 @@ def set_in_channels(self, in_channels, pretrained=True):
3134
model=self, new_in_channels=in_channels, pretrained=pretrained
3235
)
3336

34-
def get_stages(self):
35-
"""Override it in your implementation"""
37+
def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]:
38+
"""Override it in your implementation, should return a dictionary with keys as
39+
the output stride and values as the list of modules
40+
"""
3641
raise NotImplementedError
3742

3843
def make_dilated(self, output_stride):
39-
if output_stride == 16:
40-
stage_list = [5]
41-
dilation_list = [2]
42-
43-
elif output_stride == 8:
44-
stage_list = [4, 5]
45-
dilation_list = [2, 4]
46-
47-
else:
48-
raise ValueError(
49-
"Output stride should be 16 or 8, got {}.".format(output_stride)
50-
)
51-
52-
self._output_stride = output_stride
44+
if output_stride not in [8, 16]:
45+
raise ValueError(f"Output stride should be 16 or 8, got {output_stride}.")
5346

5447
stages = self.get_stages()
55-
for stage_indx, dilation_rate in zip(stage_list, dilation_list):
56-
utils.replace_strides_with_dilation(
57-
module=stages[stage_indx], dilation_rate=dilation_rate
58-
)
48+
for stage_stride, stage_modules in stages.items():
49+
if stage_stride <= output_stride:
50+
continue
51+
52+
dilation_rate = stage_stride // output_stride
53+
for module in stage_modules:
54+
utils.replace_strides_with_dilation(module, dilation_rate)

segmentation_models_pytorch/encoders/dpn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ def __init__(self, stage_idxs, out_channels, depth=5, **kwargs):
4343

4444
del self.last_linear
4545

46+
def get_stages(self):
47+
return {
48+
16: self.features[self._stage_idxs[1] : self._stage_idxs[2]],
49+
32: self.features[self._stage_idxs[2] : self._stage_idxs[3]],
50+
}
51+
4652
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
4753
features = [x]
4854

segmentation_models_pytorch/encoders/efficientnet.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ def __init__(self, stage_idxs, out_channels, model_name, depth=5):
4444

4545
del self._fc
4646

47+
def get_stages(self):
48+
return {
49+
16: self._blocks[self._stage_idxs[1] : self._stage_idxs[2]],
50+
32: self._blocks[self._stage_idxs[2] :],
51+
}
52+
4753
def apply_blocks(
4854
self, x: torch.Tensor, start_idx: int, end_idx: int
4955
) -> torch.Tensor:

segmentation_models_pytorch/encoders/mix_transformer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,12 @@ def __init__(self, out_channels, depth=5, **kwargs):
526526
self._depth = depth
527527
self._in_channels = 3
528528

529+
def get_stages(self):
530+
return {
531+
16: [self.patch_embed3, self.block3, self.norm3],
532+
32: [self.patch_embed4, self.block4, self.norm4],
533+
}
534+
529535
def forward(self, x):
530536
# create dummy output for the first block
531537
batch_size, _, height, width = x.shape

segmentation_models_pytorch/encoders/mobilenet.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ def __init__(self, out_channels, depth=5, **kwargs):
3636
self._in_channels = 3
3737
del self.classifier
3838

39+
def get_stages(self):
40+
return {
41+
16: self.features[7:14],
42+
32: self.features[14:],
43+
}
44+
3945
def forward(self, x):
4046
features = [x]
4147

segmentation_models_pytorch/encoders/mobileone.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,12 @@ def __init__(
355355
num_se_blocks=num_blocks_per_stage[3] if use_se else 0,
356356
)
357357

358+
def get_stages(self):
359+
return {
360+
16: self.stage3,
361+
32: self.stage4,
362+
}
363+
358364
def _make_stage(
359365
self, planes: int, num_blocks: int, num_se_blocks: int
360366
) -> nn.Sequential:

segmentation_models_pytorch/encoders/resnet.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ def __init__(self, out_channels, depth=5, **kwargs):
4444
del self.fc
4545
del self.avgpool
4646

47+
def get_stages(self):
48+
return {
49+
16: self.layer3,
50+
32: self.layer4,
51+
}
52+
4753
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
4854
features = [x]
4955

segmentation_models_pytorch/encoders/senet.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
2424
"""
2525

26-
import torch.nn as nn
27-
2826
from pretrainedmodels.models.senet import (
2927
SENet,
3028
SEBottleneck,
@@ -46,14 +44,10 @@ def __init__(self, out_channels, depth=5, **kwargs):
4644
del self.avg_pool
4745

4846
def get_stages(self):
49-
return [
50-
nn.Identity(),
51-
self.layer0[:-1],
52-
nn.Sequential(self.layer0[-1], self.layer1),
53-
self.layer2,
54-
self.layer3,
55-
self.layer4,
56-
]
47+
return {
48+
16: self.layer3,
49+
32: self.layer4,
50+
}
5751

5852
def forward(self, x):
5953
features = [x]

segmentation_models_pytorch/encoders/timm_efficientnet.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ def __init__(self, stage_idxs, out_channels, depth=5, **kwargs):
105105

106106
del self.classifier
107107

108+
def get_stages(self):
109+
return {
110+
16: self.blocks[self._stage_idxs[1] : self._stage_idxs[2]],
111+
32: self.blocks[self._stage_idxs[2] :],
112+
}
113+
108114
def forward(self, x):
109115
features = [x]
110116

segmentation_models_pytorch/encoders/timm_sknet.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ def __init__(self, out_channels, depth=5, **kwargs):
1313
del self.fc
1414
del self.global_pool
1515

16+
def get_stages(self):
17+
return {
18+
16: self.layer3,
19+
32: self.layer4,
20+
}
21+
1622
def forward(self, x):
1723
features = [x]
1824

0 commit comments

Comments
 (0)