forked from AIFengheshu/Plug-play-modules
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path(2023) MLCA.py
More file actions
67 lines (52 loc) · 2.98 KB
/
(2023) MLCA.py
File metadata and controls
67 lines (52 loc) · 2.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
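# Paper title: Mixed local channel attention for object detection
# Chinese title: 混合局部通道注意力用于目标检测
# Paper link: https://doi.org/10.1016/j.engappai.2023.106442
# Official GitHub: https://github.com/wandahangFY/MLCA
# Affiliation: School of Instrument Science and Opto-electronics Engineering, Hefei University of Technology; Anhui Province Key Laboratory of Measuring Theory and Precision Instrument, Hefei University of Technology
# Code curated by: WeChat official account "AI缝合术"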
class MLCA(nn.Module):
    """Mixed Local Channel Attention block.

    The input feature map is average-pooled to a ``local_size x local_size``
    grid (local descriptor) and further to ``1 x 1`` (global descriptor).
    Each descriptor is passed through its own bias-free 1-D convolution and a
    sigmoid, the two attention maps are blended with ``local_weight``, resized
    to the input resolution and multiplied with the input.

    Args:
        in_size: Channel count; used only to derive the 1-D convolution
            kernel size (ECA-style rule). Must be positive.
        local_size: Side length of the local pooling grid (an int; it is
            squared when reshaping in ``forward``).
        gamma: Divisor in the kernel-size rule.
        b: Offset in the kernel-size rule.
        local_weight: Blend factor; the result is
            ``global * (1 - local_weight) + local * local_weight``.

    Raises:
        ValueError: If ``in_size`` is not positive.
    """

    def __init__(self, in_size, local_size=5, gamma=2, b=1, local_weight=0.5):
        super().__init__()
        if in_size <= 0:
            # math.log would raise a bare "math domain error" here anyway;
            # fail with a message that names the offending argument.
            raise ValueError(f"in_size must be positive, got {in_size!r}")

        self.local_size = local_size
        self.gamma = gamma
        self.b = b

        # ECA kernel-size rule: scale log2(channels), then force the size odd
        # so that padding=(k - 1) // 2 keeps the sequence length unchanged.
        t = int(abs(math.log(in_size, 2) + self.b) / self.gamma)
        k = t if t % 2 else t + 1
        pad = (k - 1) // 2

        # Creation order (conv, then conv_local) is kept so that seeded
        # weight initialisation stays reproducible.
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=pad, bias=False)
        self.conv_local = nn.Conv1d(1, 1, kernel_size=k, padding=pad, bias=False)

        self.local_weight = local_weight
        self.local_arv_pool = nn.AdaptiveAvgPool2d(local_size)
        self.global_arv_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Apply mixed local/global channel attention to ``x``.

        Args:
            x: 4-D tensor unpacked as ``(b, c, m, n)``.

        Returns:
            Tensor of the same shape as ``x``: ``x`` scaled element-wise by
            the blended attention map.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"MLCA expects a 4-D input (b, c, m, n), got {x.dim()} dims"
            )
        b, c, m, n = x.shape
        ls = self.local_size

        local_arv = self.local_arv_pool(x)            # (b, c, ls, ls)
        global_arv = self.global_arv_pool(local_arv)  # (b, c, 1, 1)

        # (b, c, ls, ls) -> (b, c, ls*ls) -> (b, ls*ls, c) -> (b, 1, ls*ls*c)
        temp_local = local_arv.reshape(b, c, -1).transpose(-1, -2).reshape(b, 1, -1)
        # (b, c, 1, 1) -> (b, c, 1) -> (b, 1, c)
        temp_global = global_arv.reshape(b, c, -1).transpose(-1, -2)

        y_local = self.conv_local(temp_local)
        y_global = self.conv(temp_global)

        # Inverse of the flattening above:
        # (b, 1, ls*ls*c) -> (b, ls*ls, c) -> (b, c, ls*ls) -> (b, c, ls, ls)
        y_local_transpose = (
            y_local.reshape(b, ls * ls, c).transpose(-1, -2).reshape(b, c, ls, ls)
        )
        # (b, 1, c) -> (b, c) -> (b, c, 1, 1)
        y_global_transpose = y_global.reshape(b, -1).unsqueeze(-1).unsqueeze(-1)

        att_local = y_local_transpose.sigmoid()
        # Spread the per-channel global weight over the local grid.
        att_global = F.adaptive_avg_pool2d(y_global_transpose.sigmoid(), [ls, ls])

        # "Un-pooling": resize the blended map to the input's spatial size.
        blended = att_global * (1 - self.local_weight) + (att_local * self.local_weight)
        att_all = F.adaptive_avg_pool2d(blended, [m, n])

        return x * att_all
if __name__ == '__main__':
    # Smoke test: run one random batch through the module and print shapes.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_tensor = torch.rand(1, 32, 256, 256).to(device)
    attention_module = MLCA(32).to(device)
    output_tensor = attention_module(input_tensor)
    # Print the input and output tensor shapes.
    print(f"输入张量形状: {input_tensor.shape}")
    print(f"输出张量形状: {output_tensor.shape}")