31
31
32
32
33
33
class EfficientNetEncoder (EfficientNet , EncoderMixin ):
34
- _is_torch_scriptable = False
35
-
36
34
def __init__ (
37
35
self ,
38
36
stage_idxs : List [int ],
@@ -41,7 +39,7 @@ def __init__(
41
39
depth : int = 5 ,
42
40
output_stride : int = 32 ,
43
41
):
44
- if depth > 5 or depth < 1 :
42
+ if depth > 5 or depth < 2 :
45
43
raise ValueError (
46
44
f"{ self .__class__ .__name__ } depth should be in range [1, 5], got { depth } "
47
45
)
@@ -50,11 +48,13 @@ def __init__(
50
48
super ().__init__ (blocks_args , global_params )
51
49
52
50
self ._stage_idxs = stage_idxs
51
+ self ._out_indexes = [x - 1 for x in stage_idxs ]
53
52
self ._depth = depth
54
53
self ._in_channels = 3
55
54
self ._out_channels = out_channels
56
55
self ._output_stride = output_stride
57
56
57
+ self ._drop_connect_rate = self ._global_params .drop_connect_rate
58
58
del self ._fc
59
59
60
60
def get_stages (self ) -> Dict [int , Sequence [torch .nn .Module ]]:
@@ -63,17 +63,6 @@ def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]:
63
63
32 : [self ._blocks [self ._stage_idxs [2 ] :]],
64
64
}
65
65
66
- def apply_blocks (
67
- self , x : torch .Tensor , start_idx : int , end_idx : int
68
- ) -> torch .Tensor :
69
- drop_connect_rate = self ._global_params .drop_connect_rate
70
-
71
- for block_number in range (start_idx , end_idx ):
72
- drop_connect_prob = drop_connect_rate * block_number / len (self ._blocks )
73
- x = self ._blocks [block_number ](x , drop_connect_prob )
74
-
75
- return x
76
-
77
66
def forward (self , x : torch .Tensor ) -> List [torch .Tensor ]:
78
67
features = [x ]
79
68
@@ -83,21 +72,19 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
83
72
x = self ._swish (x )
84
73
features .append (x )
85
74
86
- if self ._depth >= 2 :
87
- x = self .apply_blocks (x , 0 , self ._stage_idxs [0 ])
88
- features .append (x )
75
+ depth = 1
76
+ for i , block in enumerate (self ._blocks ):
77
+ drop_connect_prob = self ._drop_connect_rate * i / len (self ._blocks )
78
+ x = block (x , drop_connect_prob )
89
79
90
- if self . _depth >= 3 :
91
- x = self . apply_blocks ( x , self . _stage_idxs [ 0 ], self . _stage_idxs [ 1 ] )
92
- features . append ( x )
80
+ if i in self . _out_indexes :
81
+ features . append ( x )
82
+ depth += 1
93
83
94
- if self ._depth >= 4 :
95
- x = self .apply_blocks (x , self ._stage_idxs [1 ], self ._stage_idxs [2 ])
96
- features .append (x )
84
+ if not torch .jit .is_scripting () and depth > self ._depth :
85
+ break
97
86
98
- if self ._depth >= 5 :
99
- x = self .apply_blocks (x , self ._stage_idxs [2 ], len (self ._blocks ))
100
- features .append (x )
87
+ features = features [: self ._depth + 1 ]
101
88
102
89
return features
103
90
0 commit comments