Skip to content

Commit d44509a

Browse files
committed
Fix encoders (mobilenet, inceptionv4)
1 parent ead24b4 commit d44509a

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

segmentation_models_pytorch/encoders/inceptionv4.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
class InceptionV4Encoder(InceptionV4, EncoderMixin):
3636
def __init__(
3737
self,
38-
out_indexes: List[int],
3938
out_channels: List[int],
4039
depth: int = 5,
4140
output_stride: int = 32,
@@ -47,7 +46,7 @@ def __init__(
4746
self._in_channels = 3
4847
self._out_channels = out_channels
4948
self._output_stride = output_stride
50-
self._out_indexes = out_indexes
49+
self._out_indexes = [2, 4, 8, 14, len(self.features) - 1]
5150

5251
# correct paddings
5352
for m in self.modules():
@@ -115,7 +114,6 @@ def load_state_dict(self, state_dict, **kwargs):
115114
},
116115
},
117116
"params": {
118-
"out_indexes": [2, 4, 8, 14],
119117
"out_channels": [3, 64, 192, 384, 1024, 1536],
120118
"num_classes": 1001,
121119
},

segmentation_models_pytorch/encoders/mobilenet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(
4040
self._in_channels = 3
4141
self._out_channels = out_channels
4242
self._output_stride = output_stride
43-
self._out_indexes = [2, 4, 7, 14]
43+
self._out_indexes = [1, 3, 6, 13, len(self.features) - 1]
4444

4545
del self.classifier
4646

tests/models/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,17 @@ def test_classification_head(self):
148148

149149
def test_any_resolution(self):
150150
model = self.get_default_model()
151-
if model.requires_divisible_input_shape:
152-
self.skipTest("Model requires divisible input shape")
153151

154152
sample = self._get_sample(
155153
height=self.default_height + 3,
156154
width=self.default_width + 7,
157155
).to(default_device)
158156

157+
if model.requires_divisible_input_shape:
158+
with self.assertRaises(RuntimeError, msg="Wrong input shape"):
159+
output = model(sample)
160+
return
161+
159162
with torch.inference_mode():
160163
output = model(sample)
161164

0 commit comments

Comments
 (0)