3131
3232
3333class EfficientNetEncoder (EfficientNet , EncoderMixin ):
34- _is_torch_scriptable = False
35-
3634 def __init__ (
3735 self ,
3836 stage_idxs : List [int ],
@@ -41,7 +39,7 @@ def __init__(
4139 depth : int = 5 ,
4240 output_stride : int = 32 ,
4341 ):
44- if depth > 5 or depth < 1 :
42+ if depth > 5 or depth < 2 :
4543 raise ValueError (
4644 f"{ self .__class__ .__name__ } depth should be in range [1, 5], got { depth } "
4745 )
@@ -50,11 +48,13 @@ def __init__(
5048 super ().__init__ (blocks_args , global_params )
5149
5250 self ._stage_idxs = stage_idxs
51+ self ._out_indexes = [x - 1 for x in stage_idxs ]
5352 self ._depth = depth
5453 self ._in_channels = 3
5554 self ._out_channels = out_channels
5655 self ._output_stride = output_stride
5756
57+ self ._drop_connect_rate = self ._global_params .drop_connect_rate
5858 del self ._fc
5959
6060 def get_stages (self ) -> Dict [int , Sequence [torch .nn .Module ]]:
@@ -63,17 +63,6 @@ def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]:
6363 32 : [self ._blocks [self ._stage_idxs [2 ] :]],
6464 }
6565
66- def apply_blocks (
67- self , x : torch .Tensor , start_idx : int , end_idx : int
68- ) -> torch .Tensor :
69- drop_connect_rate = self ._global_params .drop_connect_rate
70-
71- for block_number in range (start_idx , end_idx ):
72- drop_connect_prob = drop_connect_rate * block_number / len (self ._blocks )
73- x = self ._blocks [block_number ](x , drop_connect_prob )
74-
75- return x
76-
7766 def forward (self , x : torch .Tensor ) -> List [torch .Tensor ]:
7867 features = [x ]
7968
@@ -83,21 +72,19 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
8372 x = self ._swish (x )
8473 features .append (x )
8574
86- if self ._depth >= 2 :
87- x = self .apply_blocks (x , 0 , self ._stage_idxs [0 ])
88- features .append (x )
75+ depth = 1
76+ for i , block in enumerate (self ._blocks ):
77+ drop_connect_prob = self ._drop_connect_rate * i / len (self ._blocks )
78+ x = block (x , drop_connect_prob )
8979
90- if self . _depth >= 3 :
91- x = self . apply_blocks ( x , self . _stage_idxs [ 0 ], self . _stage_idxs [ 1 ] )
92- features . append ( x )
80+ if i in self . _out_indexes :
81+ features . append ( x )
82+ depth += 1
9383
94- if self ._depth >= 4 :
95- x = self .apply_blocks (x , self ._stage_idxs [1 ], self ._stage_idxs [2 ])
96- features .append (x )
84+ if not torch .jit .is_scripting () and depth > self ._depth :
85+ break
9786
98- if self ._depth >= 5 :
99- x = self .apply_blocks (x , self ._stage_idxs [2 ], len (self ._blocks ))
100- features .append (x )
87+ features = features [: self ._depth + 1 ]
10188
10289 return features
10390
0 commit comments