@@ -156,33 +156,47 @@ def load_state_dict(self, state_dict, **kwargs):
156
156
class EfficientNetEncoder (EfficientNetBaseEncoder ):
157
157
def __init__ (
158
158
self ,
159
- stage_idxs ,
160
- out_channels ,
161
- depth = 5 ,
162
- channel_multiplier = 1.0 ,
163
- depth_multiplier = 1.0 ,
164
- drop_rate = 0.2 ,
159
+ stage_idxs : List [int ],
160
+ out_channels : List [int ],
161
+ depth : int = 5 ,
162
+ channel_multiplier : float = 1.0 ,
163
+ depth_multiplier : float = 1.0 ,
164
+ drop_rate : float = 0.2 ,
165
+ output_stride : int = 32 ,
165
166
):
166
167
kwargs = get_efficientnet_kwargs (
167
168
channel_multiplier , depth_multiplier , drop_rate
168
169
)
169
- super ().__init__ (stage_idxs , out_channels , depth , ** kwargs )
170
+ super ().__init__ (
171
+ stage_idxs = stage_idxs ,
172
+ depth = depth ,
173
+ out_channels = out_channels ,
174
+ output_stride = output_stride ,
175
+ ** kwargs ,
176
+ )
170
177
171
178
172
179
class EfficientNetLiteEncoder (EfficientNetBaseEncoder ):
173
180
def __init__ (
174
181
self ,
175
- stage_idxs ,
176
- out_channels ,
177
- depth = 5 ,
178
- channel_multiplier = 1.0 ,
179
- depth_multiplier = 1.0 ,
180
- drop_rate = 0.2 ,
182
+ stage_idxs : List [int ],
183
+ out_channels : List [int ],
184
+ depth : int = 5 ,
185
+ channel_multiplier : float = 1.0 ,
186
+ depth_multiplier : float = 1.0 ,
187
+ drop_rate : float = 0.2 ,
188
+ output_stride : int = 32 ,
181
189
):
182
190
kwargs = gen_efficientnet_lite_kwargs (
183
191
channel_multiplier , depth_multiplier , drop_rate
184
192
)
185
- super ().__init__ (stage_idxs , out_channels , depth , ** kwargs )
193
+ super ().__init__ (
194
+ stage_idxs = stage_idxs ,
195
+ depth = depth ,
196
+ out_channels = out_channels ,
197
+ output_stride = output_stride ,
198
+ ** kwargs ,
199
+ )
186
200
187
201
188
202
def prepare_settings (settings ):
0 commit comments