import torch.nn as nn
import torch.nn.functional as F

-from ..common.blocks import Conv2dReLU
from ..base.model import Model


+class Conv3x3GNReLU(nn.Module):
+    def __init__(self, in_channels, out_channels, upsample=False):
+        super().__init__()
+        self.upsample = upsample
+        self.block = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, (3, 3),
+                      stride=1, padding=1, bias=False),
+            nn.GroupNorm(32, out_channels),
+            nn.ReLU(inplace=True),
+        )
+
+    def forward(self, x):
+        x = self.block(x)
+        if self.upsample:
+            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
+        return x
+
+
class FPNBlock(nn.Module):
    def __init__(self, pyramid_channels, skip_channels):
        super().__init__()
@@ -22,12 +41,18 @@ def forward(self, x):


class SegmentationBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+    def __init__(self, in_channels, out_channels, n_upsamples=0):
        super().__init__()
-        self.block = nn.Sequential(
-            Conv2dReLU(in_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
-            Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
-        )
+
+        blocks = [
+            Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))
+        ]
+
+        if n_upsamples > 1:
+            for _ in range(1, n_upsamples):
+                blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True))
+
+        self.block = nn.Sequential(*blocks)

    def forward(self, x):
        return self.block(x)
@@ -42,7 +67,6 @@ def __init__(
            segmentation_channels=128,
            final_channels=1,
            dropout=0.2,
-            use_batchnorm=True,
    ):
        super().__init__()

@@ -52,13 +76,13 @@ def __init__(
        self.p3 = FPNBlock(pyramid_channels, encoder_channels[2])
        self.p2 = FPNBlock(pyramid_channels, encoder_channels[3])

-        self.s5 = SegmentationBlock(pyramid_channels, segmentation_channels, use_batchnorm=use_batchnorm)
-        self.s4 = SegmentationBlock(pyramid_channels, segmentation_channels, use_batchnorm=use_batchnorm)
-        self.s3 = SegmentationBlock(pyramid_channels, segmentation_channels, use_batchnorm=use_batchnorm)
-        self.s2 = SegmentationBlock(pyramid_channels, segmentation_channels, use_batchnorm=use_batchnorm)
+        self.s5 = SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=3)
+        self.s4 = SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=2)
+        self.s3 = SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=1)
+        self.s2 = SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=0)

        self.dropout = nn.Dropout2d(p=dropout, inplace=True)
-        self.final_conv = nn.Conv2d(4 * segmentation_channels, final_channels, kernel_size=3, padding=1)
+        self.final_conv = nn.Conv2d(segmentation_channels, final_channels, kernel_size=1, padding=0)

        self.initialize()

@@ -75,12 +99,7 @@ def forward(self, x):
        s3 = self.s3(p3)
        s2 = self.s2(p2)

-        x = torch.cat([
-            F.interpolate(s5, scale_factor=8, mode='bilinear', align_corners=True),
-            F.interpolate(s4, scale_factor=4, mode='bilinear', align_corners=True),
-            F.interpolate(s3, scale_factor=2, mode='bilinear', align_corners=True),
-            s2,
-        ], dim=1)
+        x = s5 + s4 + s3 + s2

        x = self.dropout(x)
        x = self.final_conv(x)
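
A minimal, self-contained sketch (not part of the patch) that re-declares the two new blocks from the diff and checks that n_upsamples values of 3, 2, 1 and 0 bring the s5 to s2 branches to a common stride-4 resolution, which is what makes the new element-wise sum and the 1x1 final_conv valid. The 256x256-equivalent feature sizes, batch size of 1 and pyramid_channels=256 are illustrative assumptions, not values fixed by the diff.

# Standalone sketch: blocks copied from the patch above; tensor shapes assume
# a ResNet-style encoder (strides 32/16/8/4) on a 256x256 input.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv3x3GNReLU(nn.Module):
    """3x3 conv + GroupNorm + ReLU, optionally followed by 2x bilinear upsampling."""

    def __init__(self, in_channels, out_channels, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False),
            nn.GroupNorm(32, out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.block(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        return x


class SegmentationBlock(nn.Module):
    """Chain of Conv3x3GNReLU blocks; each upsampling block doubles the resolution."""

    def __init__(self, in_channels, out_channels, n_upsamples=0):
        super().__init__()
        blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))]
        if n_upsamples > 1:
            for _ in range(1, n_upsamples):
                blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True))
        self.block = nn.Sequential(*blocks)

    def forward(self, x):
        return self.block(x)


# Pyramid features p5..p2 as they would look after the FPN lateral convolutions.
pyramid_channels, segmentation_channels = 256, 128
p5 = torch.randn(1, pyramid_channels, 8, 8)    # stride 32
p4 = torch.randn(1, pyramid_channels, 16, 16)  # stride 16
p3 = torch.randn(1, pyramid_channels, 32, 32)  # stride 8
p2 = torch.randn(1, pyramid_channels, 64, 64)  # stride 4

s5 = SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=3)(p5)
s4 = SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=2)(p4)
s3 = SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=1)(p3)
s2 = SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=0)(p2)

# 3, 2, 1 and 0 doublings land every branch at the stride-4 resolution, so the
# element-wise sum used by the new forward() is shape-compatible, and final_conv
# only needs segmentation_channels inputs (a 1x1 conv) instead of the
# 4 * segmentation_channels required by the old concatenation.
assert s5.shape == s4.shape == s3.shape == s2.shape == (1, 128, 64, 64)
x = s5 + s4 + s3 + s2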