2
2
import torch .nn as nn
3
3
import torch .nn .functional as F
4
4
5
- from ..common .blocks import Conv2dReLU
5
+ from ..common .blocks import Conv2dReLU , SCSEModule
6
6
from ..base .model import Model
7
7
8
8
9
9
class DecoderBlock (nn .Module ):
10
- def __init__ (self , in_channels , out_channels , use_batchnorm = True ):
10
+ def __init__ (self , in_channels , out_channels , use_batchnorm = True , attention_type = None ):
11
11
super ().__init__ ()
12
+ if attention_type is None :
13
+ self .attention1 = nn .Identity ()
14
+ self .attention2 = nn .Identity ()
15
+ elif attention_type == 'scse' :
16
+ self .attention1 = SCSEModule (in_channels )
17
+ self .attention2 = SCSEModule (out_channels )
18
+
12
19
self .block = nn .Sequential (
13
20
Conv2dReLU (in_channels , out_channels , kernel_size = 3 , padding = 1 , use_batchnorm = use_batchnorm ),
14
21
Conv2dReLU (out_channels , out_channels , kernel_size = 3 , padding = 1 , use_batchnorm = use_batchnorm ),
@@ -19,7 +26,10 @@ def forward(self, x):
19
26
x = F .interpolate (x , scale_factor = 2 , mode = 'nearest' )
20
27
if skip is not None :
21
28
x = torch .cat ([x , skip ], dim = 1 )
29
+ x = self .attention1 (x )
30
+
22
31
x = self .block (x )
32
+ x = self .attention2 (x )
23
33
return x
24
34
25
35
@@ -38,6 +48,7 @@ def __init__(
38
48
final_channels = 1 ,
39
49
use_batchnorm = True ,
40
50
center = False ,
51
+ attention_type = None
41
52
):
42
53
super ().__init__ ()
43
54
@@ -50,11 +61,16 @@ def __init__(
50
61
in_channels = self .compute_channels (encoder_channels , decoder_channels )
51
62
out_channels = decoder_channels
52
63
53
- self .layer1 = DecoderBlock (in_channels [0 ], out_channels [0 ], use_batchnorm = use_batchnorm )
54
- self .layer2 = DecoderBlock (in_channels [1 ], out_channels [1 ], use_batchnorm = use_batchnorm )
55
- self .layer3 = DecoderBlock (in_channels [2 ], out_channels [2 ], use_batchnorm = use_batchnorm )
56
- self .layer4 = DecoderBlock (in_channels [3 ], out_channels [3 ], use_batchnorm = use_batchnorm )
57
- self .layer5 = DecoderBlock (in_channels [4 ], out_channels [4 ], use_batchnorm = use_batchnorm )
64
+ self .layer1 = DecoderBlock (in_channels [0 ], out_channels [0 ],
65
+ use_batchnorm = use_batchnorm , attention_type = attention_type )
66
+ self .layer2 = DecoderBlock (in_channels [1 ], out_channels [1 ],
67
+ use_batchnorm = use_batchnorm , attention_type = attention_type )
68
+ self .layer3 = DecoderBlock (in_channels [2 ], out_channels [2 ],
69
+ use_batchnorm = use_batchnorm , attention_type = attention_type )
70
+ self .layer4 = DecoderBlock (in_channels [3 ], out_channels [3 ],
71
+ use_batchnorm = use_batchnorm , attention_type = attention_type )
72
+ self .layer5 = DecoderBlock (in_channels [4 ], out_channels [4 ],
73
+ use_batchnorm = use_batchnorm , attention_type = attention_type )
58
74
self .final_conv = nn .Conv2d (out_channels [4 ], final_channels , kernel_size = (1 , 1 ))
59
75
60
76
self .initialize ()
0 commit comments