Skip to content

Commit 4aa7cd1

Browse files
committed
Update encoder
1 parent ce1ae43 commit 4aa7cd1

File tree

1 file changed

+13
-26
lines changed

1 file changed

+13
-26
lines changed

segmentation_models_pytorch/encoders/efficientnet.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131

3232

3333
class EfficientNetEncoder(EfficientNet, EncoderMixin):
34-
_is_torch_scriptable = False
35-
3634
def __init__(
3735
self,
3836
stage_idxs: List[int],
@@ -41,7 +39,7 @@ def __init__(
4139
depth: int = 5,
4240
output_stride: int = 32,
4341
):
44-
if depth > 5 or depth < 1:
42+
if depth > 5 or depth < 2:
4543
raise ValueError(
4644
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
4745
)
@@ -50,11 +48,13 @@ def __init__(
5048
super().__init__(blocks_args, global_params)
5149

5250
self._stage_idxs = stage_idxs
51+
self._out_indexes = [x - 1 for x in stage_idxs]
5352
self._depth = depth
5453
self._in_channels = 3
5554
self._out_channels = out_channels
5655
self._output_stride = output_stride
5756

57+
self._drop_connect_rate = self._global_params.drop_connect_rate
5858
del self._fc
5959

6060
def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]:
@@ -63,17 +63,6 @@ def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]:
6363
32: [self._blocks[self._stage_idxs[2] :]],
6464
}
6565

66-
def apply_blocks(
67-
self, x: torch.Tensor, start_idx: int, end_idx: int
68-
) -> torch.Tensor:
69-
drop_connect_rate = self._global_params.drop_connect_rate
70-
71-
for block_number in range(start_idx, end_idx):
72-
drop_connect_prob = drop_connect_rate * block_number / len(self._blocks)
73-
x = self._blocks[block_number](x, drop_connect_prob)
74-
75-
return x
76-
7766
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
7867
features = [x]
7968

@@ -83,21 +72,19 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
8372
x = self._swish(x)
8473
features.append(x)
8574

86-
if self._depth >= 2:
87-
x = self.apply_blocks(x, 0, self._stage_idxs[0])
88-
features.append(x)
75+
depth = 1
76+
for i, block in enumerate(self._blocks):
77+
drop_connect_prob = self._drop_connect_rate * i / len(self._blocks)
78+
x = block(x, drop_connect_prob)
8979

90-
if self._depth >= 3:
91-
x = self.apply_blocks(x, self._stage_idxs[0], self._stage_idxs[1])
92-
features.append(x)
80+
if i in self._out_indexes:
81+
features.append(x)
82+
depth += 1
9383

94-
if self._depth >= 4:
95-
x = self.apply_blocks(x, self._stage_idxs[1], self._stage_idxs[2])
96-
features.append(x)
84+
if not torch.jit.is_scripting() and depth > self._depth:
85+
break
9786

98-
if self._depth >= 5:
99-
x = self.apply_blocks(x, self._stage_idxs[2], len(self._blocks))
100-
features.append(x)
87+
features = features[: self._depth + 1]
10188

10289
return features
10390

0 commit comments

Comments
 (0)