(ECCV 2018) CBAM.py
import torch
from torch import nn
from torch.nn import init
# Paper: CBAM: Convolutional Block Attention Module (ECCV 2018)
# Paper link: https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf
# Unofficial GitHub implementation: https://github.com/luuuyi/CBAM.PyTorch
# Affiliations: Korea Advanced Institute of Science and Technology (KAIST), Daejeon, Korea;
#               Lunit Inc., Seoul, Korea; Adobe Research, San Jose, CA, USA
# Keywords: object detection, image classification, attention mechanism, gated convolution
# WeChat official account: AI缝合术
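# CBAM refines a feature map F in two sequential steps (per the paper):
#   channel attention:  Mc(F) = sigmoid(MLP(AvgPool(F)) + MLP(MaxPool(F)))
#   spatial attention:  Ms(F) = sigmoid(conv7x7([AvgPool(F); MaxPool(F)]))
# The modules below implement these two maps and combine them in CBAMBlock.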
class ChannelAttention(nn.Module):
    """Channel attention: squeeze spatial dims with max- and avg-pooling,
    pass both through a shared MLP, and gate channels with a sigmoid."""
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Shared MLP, implemented with 1x1 convolutions and a channel // reduction bottleneck
        self.se = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channel // reduction, channel, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_result = self.maxpool(x)   # (B, C, 1, 1)
        avg_result = self.avgpool(x)   # (B, C, 1, 1)
        max_out = self.se(max_result)
        avg_out = self.se(avg_result)
        output = self.sigmoid(max_out + avg_out)
        return output                  # channel attention map, (B, C, 1, 1)
class SpatialAttention(nn.Module):
    """Spatial attention: pool across channels (max and mean), concatenate,
    and convolve down to a single-channel sigmoid map."""
    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size % 2 == 1, "Kernel size must be odd."
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_result, _ = torch.max(x, dim=1, keepdim=True)     # (B, 1, H, W)
        avg_result = torch.mean(x, dim=1, keepdim=True)       # (B, 1, H, W)
        result = torch.cat([max_result, avg_result], dim=1)   # (B, 2, H, W)
        output = self.conv(result)
        output = self.sigmoid(output)
        return output                                          # spatial attention map, (B, 1, H, W)
class CBAMBlock(nn.Module):
    """CBAM: channel attention followed by spatial attention, with a residual connection."""
    def __init__(self, channel=512, reduction=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channel=channel, reduction=reduction)
        self.sa = SpatialAttention(kernel_size=kernel_size)

    def init_weights(self):
        # Optional initializer; call it explicitly after constructing the block.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        residual = x
        out = x * self.ca(x)       # channel-wise reweighting
        out = out * self.sa(out)   # spatial reweighting
        return out + residual
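# A minimal sketch (not part of the original file or the paper) of how CBAMBlock
# might be dropped into a plain conv stage; the wrapper name and layer layout
# below are illustrative assumptions, not the repo's API.
class ConvBNReLUWithCBAM(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=16, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # CBAM is applied to the stage's output feature map.
        self.cbam = CBAMBlock(channel=out_channels, reduction=reduction, kernel_size=kernel_size)

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        return self.cbam(x)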
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Ensure kernel_size is appropriate for the input's spatial size.
    cbam = CBAMBlock(channel=256, reduction=16, kernel_size=7).to(device)
    input = torch.rand(1, 256, 64, 64).to(device)
    output = cbam(input)
    print(f"\nInput shape: {input.shape}")
    print(f"Output shape: {output.shape}")