@@ -57,7 +57,8 @@ class BiFPN(nn.Module):
5757 modified by Zylo117
5858 """
5959
60- def __init__ (self , num_channels , conv_channels , first_time = False , epsilon = 1e-4 , onnx_export = False , attention = True ):
60+ def __init__ (self , num_channels , conv_channels , first_time = False , epsilon = 1e-4 , onnx_export = False , attention = True ,
61+ use_p8 = False ):
6162 """
6263
6364 Args:
@@ -70,6 +71,8 @@ def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4,
7071 """
7172 super (BiFPN , self ).__init__ ()
7273 self .epsilon = epsilon
74+ self .use_p8 = use_p8
75+
7376 # Conv layers
7477 self .conv6_up = SeparableConvBlock (num_channels , onnx_export = onnx_export )
7578 self .conv5_up = SeparableConvBlock (num_channels , onnx_export = onnx_export )
@@ -79,6 +82,9 @@ def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4,
7982 self .conv5_down = SeparableConvBlock (num_channels , onnx_export = onnx_export )
8083 self .conv6_down = SeparableConvBlock (num_channels , onnx_export = onnx_export )
8184 self .conv7_down = SeparableConvBlock (num_channels , onnx_export = onnx_export )
85+ if use_p8 :
86+ self .conv7_up = SeparableConvBlock (num_channels , onnx_export = onnx_export )
87+ self .conv8_down = SeparableConvBlock (num_channels , onnx_export = onnx_export )
8288
8389 # Feature scaling layers
8490 self .p6_upsample = nn .Upsample (scale_factor = 2 , mode = 'nearest' )
@@ -90,6 +96,9 @@ def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4,
9096 self .p5_downsample = MaxPool2dStaticSamePadding (3 , 2 )
9197 self .p6_downsample = MaxPool2dStaticSamePadding (3 , 2 )
9298 self .p7_downsample = MaxPool2dStaticSamePadding (3 , 2 )
99+ if use_p8 :
100+ self .p7_upsample = nn .Upsample (scale_factor = 2 , mode = 'nearest' )
101+ self .p8_downsample = MaxPool2dStaticSamePadding (3 , 2 )
93102
94103 self .swish = MemoryEfficientSwish () if not onnx_export else Swish ()
95104
@@ -116,6 +125,10 @@ def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4,
116125 self .p6_to_p7 = nn .Sequential (
117126 MaxPool2dStaticSamePadding (3 , 2 )
118127 )
128+ if use_p8 :
129+ self .p7_to_p8 = nn .Sequential (
130+ MaxPool2dStaticSamePadding (3 , 2 )
131+ )
119132
120133 self .p4_down_channel_2 = nn .Sequential (
121134 Conv2dStaticSamePadding (conv_channels [1 ], num_channels , 1 ),
@@ -172,11 +185,11 @@ def forward(self, inputs):
172185 # elif later phase, upsample to target phase's by nearest interpolation
173186
174187 if self .attention :
175- p3_out , p4_out , p5_out , p6_out , p7_out = self ._forward_fast_attention (inputs )
188+ outs = self ._forward_fast_attention (inputs )
176189 else :
177- p3_out , p4_out , p5_out , p6_out , p7_out = self ._forward (inputs )
190+ outs = self ._forward (inputs )
178191
179- return p3_out , p4_out , p5_out , p6_out , p7_out
192+ return outs
180193
181194 def _forward_fast_attention (self , inputs ):
182195 if self .first_time :
@@ -258,19 +271,34 @@ def _forward(self, inputs):
258271
259272 p6_in = self .p5_to_p6 (p5 )
260273 p7_in = self .p6_to_p7 (p6_in )
274+ if self .use_p8 :
275+ p8_in = self .p7_to_p8 (p7_in )
261276
262277 p3_in = self .p3_down_channel (p3 )
263278 p4_in = self .p4_down_channel (p4 )
264279 p5_in = self .p5_down_channel (p5 )
265280
266281 else :
267- # P3_0, P4_0, P5_0, P6_0 and P7_0
268- p3_in , p4_in , p5_in , p6_in , p7_in = inputs
282+ if self .use_p8 :
283+ # P3_0, P4_0, P5_0, P6_0, P7_0 and P8_0
284+ p3_in , p4_in , p5_in , p6_in , p7_in , p8_in = inputs
285+ else :
286+ # P3_0, P4_0, P5_0, P6_0 and P7_0
287+ p3_in , p4_in , p5_in , p6_in , p7_in = inputs
269288
270- # P7_0 to P7_2
289+ if self .use_p8 :
290+ # P8_0 to P8_2
271291
272- # Connections for P6_0 and P7_0 to P6_1 respectively
273- p6_up = self .conv6_up (self .swish (p6_in + self .p6_upsample (p7_in )))
292+ # Connections for P7_0 and P8_0 to P7_1 respectively
293+ p7_up = self .conv7_up (self .swish (p7_in + self .p7_upsample (p8_in )))
294+
295+ # Connections for P6_0 and P7_0 to P6_1 respectively
296+ p6_up = self .conv6_up (self .swish (p6_in + self .p6_upsample (p7_up )))
297+ else :
298+ # P7_0 to P7_2
299+
300+ # Connections for P6_0 and P7_0 to P6_1 respectively
301+ p6_up = self .conv6_up (self .swish (p6_in + self .p6_upsample (p7_in )))
274302
275303 # Connections for P5_0 and P6_1 to P5_1 respectively
276304 p5_up = self .conv5_up (self .swish (p5_in + self .p5_upsample (p6_up )))
@@ -297,26 +325,36 @@ def _forward(self, inputs):
297325 p6_out = self .conv6_down (
298326 self .swish (p6_in + p6_up + self .p6_downsample (p5_out )))
299327
300- # Connections for P7_0 and P6_2 to P7_2
301- p7_out = self .conv7_down (self .swish (p7_in + self .p7_downsample (p6_out )))
328+ if self .use_p8 :
329+ # Connections for P7_0, P7_1 and P6_2 to P7_2 respectively
330+ p7_out = self .conv7_down (
331+ self .swish (p7_in + p7_up + self .p7_downsample (p6_out )))
302332
303- return p3_out , p4_out , p5_out , p6_out , p7_out
333+ # Connections for P8_0 and P7_2 to P8_2
334+ p8_out = self .conv8_down (self .swish (p8_in + self .p8_downsample (p7_out )))
335+
336+ return p3_out , p4_out , p5_out , p6_out , p7_out , p8_out
337+ else :
338+ # Connections for P7_0 and P6_2 to P7_2
339+ p7_out = self .conv7_down (self .swish (p7_in + self .p7_downsample (p6_out )))
340+
341+ return p3_out , p4_out , p5_out , p6_out , p7_out
304342
305343
306344class Regressor (nn .Module ):
307345 """
308346 modified by Zylo117
309347 """
310348
311- def __init__ (self , in_channels , num_anchors , num_layers , onnx_export = False ):
349+ def __init__ (self , in_channels , num_anchors , num_layers , pyramid_levels = 5 , onnx_export = False ):
312350 super (Regressor , self ).__init__ ()
313351 self .num_layers = num_layers
314352
315353 self .conv_list = nn .ModuleList (
316354 [SeparableConvBlock (in_channels , in_channels , norm = False , activation = False ) for i in range (num_layers )])
317355 self .bn_list = nn .ModuleList (
318356 [nn .ModuleList ([nn .BatchNorm2d (in_channels , momentum = 0.01 , eps = 1e-3 ) for i in range (num_layers )]) for j in
319- range (5 )])
357+ range (pyramid_levels )])
320358 self .header = SeparableConvBlock (in_channels , num_anchors * 4 , norm = False , activation = False )
321359 self .swish = MemoryEfficientSwish () if not onnx_export else Swish ()
322360
@@ -344,7 +382,7 @@ class Classifier(nn.Module):
344382 modified by Zylo117
345383 """
346384
347- def __init__ (self , in_channels , num_anchors , num_classes , num_layers , onnx_export = False ):
385+ def __init__ (self , in_channels , num_anchors , num_classes , num_layers , pyramid_levels = 5 , onnx_export = False ):
348386 super (Classifier , self ).__init__ ()
349387 self .num_anchors = num_anchors
350388 self .num_classes = num_classes
@@ -353,7 +391,7 @@ def __init__(self, in_channels, num_anchors, num_classes, num_layers, onnx_expor
353391 [SeparableConvBlock (in_channels , in_channels , norm = False , activation = False ) for i in range (num_layers )])
354392 self .bn_list = nn .ModuleList (
355393 [nn .ModuleList ([nn .BatchNorm2d (in_channels , momentum = 0.01 , eps = 1e-3 ) for i in range (num_layers )]) for j in
356- range (5 )])
394+ range (pyramid_levels )])
357395 self .header = SeparableConvBlock (in_channels , num_anchors * num_classes , norm = False , activation = False )
358396 self .swish = MemoryEfficientSwish () if not onnx_export else Swish ()
359397
0 commit comments