Skip to content

Commit ebe45dc

Browse files
committed
Update indexes
1 parent 22e6b2e commit ebe45dc

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

segmentation_models_pytorch/encoders/efficientnet.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
class EfficientNetEncoder(EfficientNet, EncoderMixin):
3434
def __init__(
3535
self,
36-
stage_idxs: List[int],
36+
out_indexes: List[int],
3737
out_channels: List[int],
3838
model_name: str,
3939
depth: int = 5,
@@ -47,8 +47,7 @@ def __init__(
4747
blocks_args, global_params = get_model_params(model_name, override_params=None)
4848
super().__init__(blocks_args, global_params)
4949

50-
self._stage_idxs = stage_idxs
51-
self._out_indexes = [x - 1 for x in stage_idxs]
50+
self._out_indexes = out_indexes
5251
self._depth = depth
5352
self._in_channels = 3
5453
self._out_channels = out_channels
@@ -109,7 +108,7 @@ def load_state_dict(self, state_dict, **kwargs):
109108
},
110109
"params": {
111110
"out_channels": [3, 32, 24, 40, 112, 320],
112-
"stage_idxs": [3, 5, 9, 16],
111+
"out_indexes": [2, 4, 8, 15],
113112
"model_name": "efficientnet-b0",
114113
},
115114
},
@@ -127,7 +126,7 @@ def load_state_dict(self, state_dict, **kwargs):
127126
},
128127
"params": {
129128
"out_channels": [3, 32, 24, 40, 112, 320],
130-
"stage_idxs": [5, 8, 16, 23],
129+
"out_indexes": [4, 7, 15, 22],
131130
"model_name": "efficientnet-b1",
132131
},
133132
},
@@ -145,7 +144,7 @@ def load_state_dict(self, state_dict, **kwargs):
145144
},
146145
"params": {
147146
"out_channels": [3, 32, 24, 48, 120, 352],
148-
"stage_idxs": [5, 8, 16, 23],
147+
"out_indexes": [4, 7, 15, 22],
149148
"model_name": "efficientnet-b2",
150149
},
151150
},
@@ -163,7 +162,7 @@ def load_state_dict(self, state_dict, **kwargs):
163162
},
164163
"params": {
165164
"out_channels": [3, 40, 32, 48, 136, 384],
166-
"stage_idxs": [5, 8, 18, 26],
165+
"out_indexes": [4, 7, 17, 25],
167166
"model_name": "efficientnet-b3",
168167
},
169168
},
@@ -181,7 +180,7 @@ def load_state_dict(self, state_dict, **kwargs):
181180
},
182181
"params": {
183182
"out_channels": [3, 48, 32, 56, 160, 448],
184-
"stage_idxs": [6, 10, 22, 32],
183+
"out_indexes": [5, 9, 21, 31],
185184
"model_name": "efficientnet-b4",
186185
},
187186
},
@@ -199,7 +198,7 @@ def load_state_dict(self, state_dict, **kwargs):
199198
},
200199
"params": {
201200
"out_channels": [3, 48, 40, 64, 176, 512],
202-
"stage_idxs": [8, 13, 27, 39],
201+
"out_indexes": [7, 12, 26, 38],
203202
"model_name": "efficientnet-b5",
204203
},
205204
},
@@ -217,7 +216,7 @@ def load_state_dict(self, state_dict, **kwargs):
217216
},
218217
"params": {
219218
"out_channels": [3, 56, 40, 72, 200, 576],
220-
"stage_idxs": [9, 15, 31, 45],
219+
"out_indexes": [8, 14, 30, 44],
221220
"model_name": "efficientnet-b6",
222221
},
223222
},
@@ -235,7 +234,7 @@ def load_state_dict(self, state_dict, **kwargs):
235234
},
236235
"params": {
237236
"out_channels": [3, 64, 48, 80, 224, 640],
238-
"stage_idxs": [11, 18, 38, 55],
237+
"out_indexes": [10, 17, 37, 54],
239238
"model_name": "efficientnet-b7",
240239
},
241240
},

0 commit comments

Comments
 (0)