forked from AIFengheshu/Plug-play-modules
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path(Elsevier2023) Mult-Collaborative-Attention.py
More file actions
118 lines (91 loc) · 3.62 KB
/
(Elsevier2023) Mult-Collaborative-Attention.py
File metadata and controls
118 lines (91 loc) · 3.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
from torch import nn
import math
# Paper title: MCA: Multidimensional collaborative attention in deep convolutional neural networks for image recognition
# Chinese title: MCA:用于图像识别的深度卷积神经网络中的多维协作注意力
# Paper link: https://doi.org/10.1016/j.engappai.2023.107079
# Official GitHub: https://github.com/ndsclark/MCANet
# Affiliation: School of Electronics and Information, Northwestern Polytechnical University
# Code compiled by: WeChat official account "AI缝合术"
class StdPool(nn.Module):
    """Global standard-deviation pooling over the two trailing dimensions.

    For a 4-D input of shape ``(b, c, d2, d3)`` the module returns a tensor
    of shape ``(b, c, 1, 1)`` holding, for every ``(b, c)`` pair, the sample
    standard deviation of the ``d2 * d3`` values in that plane.
    """

    def __init__(self):
        super(StdPool, self).__init__()

    def forward(self, x):
        """Return the per-plane standard deviation of ``x``.

        Args:
            x: 4-D tensor; it does not have to be contiguous.

        Returns:
            Tensor of shape ``(b, c, 1, 1)``. Planes with fewer than two
            elements yield ``0`` instead of ``NaN``.
        """
        b, c, _, _ = x.size()
        # ``reshape`` instead of ``view``: ``view`` raises on inputs whose
        # trailing dimensions cannot be merged without a copy (e.g. the
        # result of ``permute``), whereas ``reshape`` copies when required.
        flat = x.reshape(b, c, -1)
        if flat.size(2) < 2:
            # The default (Bessel-corrected) std divides by ``n - 1`` and is
            # NaN for a single value, which would poison every downstream
            # activation. A plane with no spread has zero deviation.
            return flat.new_zeros((b, c, 1, 1))
        std = flat.std(dim=2, keepdim=True)
        return std.reshape(b, c, 1, 1)
class MCAGate(nn.Module):
    """Attention gate that rescales a 4-D tensor along its dimension 1.

    The input is squeezed over its two trailing dimensions by one or two
    global pooling operators, the pooled descriptors are fused, a 1-D
    convolution of width ``k_size`` is run along dimension 1, and the
    sigmoid of the result is multiplied back onto the input.
    """

    def __init__(self, k_size, pool_types=('avg', 'std')):
        """Constructs a MCAGate module.

        Args:
            k_size: kernel size of the convolution along dimension 1. The
                padding is ``(k_size - 1) // 2``, so only an odd value
                keeps the length unchanged (an even value makes the final
                ``expand_as`` in ``forward`` fail).
            pool_types: sequence of pooling types. 'avg': average pooling,
                'max': max pooling, 'std': standard deviation pooling.
                ``forward`` supports one or two entries.

        Raises:
            NotImplementedError: if an entry of ``pool_types`` is unknown.
        """
        super(MCAGate, self).__init__()

        self.pools = nn.ModuleList([])
        for pool_type in pool_types:
            if pool_type == 'avg':
                self.pools.append(nn.AdaptiveAvgPool2d(1))
            elif pool_type == 'max':
                self.pools.append(nn.AdaptiveMaxPool2d(1))
            elif pool_type == 'std':
                self.pools.append(StdPool())
            else:
                raise NotImplementedError(
                    "Unsupported pool type: {!r}".format(pool_type)
                )

        # A (1, k_size) kernel on a single-channel map acts as a 1-D
        # convolution along the last axis of the permuted descriptor.
        self.conv = nn.Conv2d(1, 1, kernel_size=(1, k_size), stride=1, padding=(0, (k_size - 1) // 2), bias=False)
        self.sigmoid = nn.Sigmoid()

        # Learnable mixing weights for the two pooled descriptors; only
        # read when exactly two pooling operators are configured.
        self.weight = nn.Parameter(torch.rand(2))

    def forward(self, x):
        """Apply the gate to ``x``.

        Args:
            x: 4-D tensor; the attention is computed along dimension 1.

        Returns:
            Tensor with the same shape as ``x``: ``x`` multiplied by a gate
            in (0, 1) that is constant over the two trailing dimensions.

        Raises:
            AssertionError: if the module was built with a number of
                pooling operators other than one or two.
        """
        feats = [pool(x) for pool in self.pools]

        if len(feats) == 1:
            out = feats[0]
        elif len(feats) == 2:
            weight = torch.sigmoid(self.weight)
            out = 1 / 2 * (feats[0] + feats[1]) + weight[0] * feats[0] + weight[1] * feats[1]
        else:
            # Raised explicitly: a plain ``assert`` is removed under
            # ``python -O`` and would leave ``out`` unbound below.
            raise AssertionError("Feature Extraction Exception!")

        # (b, d1, 1, 1) -> (b, 1, 1, d1) so the convolution slides over d1.
        out = out.permute(0, 3, 2, 1).contiguous()
        out = self.conv(out)
        # Back to (b, d1, 1, 1).
        out = out.permute(0, 3, 2, 1).contiguous()

        out = self.sigmoid(out)
        out = out.expand_as(x)

        return x * out
class MCALayer(nn.Module):
    """Multidimensional collaborative attention layer.

    Runs up to three ``MCAGate`` branches on the same 4-D input, each one
    attending along a different dimension (obtained by swapping that
    dimension into position 1), and averages the branch outputs.
    """

    def __init__(self, inp, no_spatial=False):
        """Constructs a MCA module.

        Args:
            inp: Number of channels of the input feature maps. Used only to
                derive the kernel size of the ``c_hw`` gate; must be
                positive because its base-2 logarithm is taken.
            no_spatial: if True, the ``c_hw`` gate (the branch applied to
                the un-permuted input) is not built and only the two
                permuted branches are averaged.
        """
        super(MCALayer, self).__init__()

        kernel = self._kernel_size(inp)

        self.h_cw = MCAGate(3)
        self.w_hc = MCAGate(3)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.c_hw = MCAGate(kernel)

    @staticmethod
    def _kernel_size(inp, lambd=1.5, gamma=1):
        """Return the odd, positive kernel size derived from ``inp``."""
        temp = round(abs((math.log2(inp) - gamma) / lambd))
        kernel = temp if temp % 2 else temp - 1
        # ``temp`` rounds to 0 for very small ``inp`` (e.g. 2 or 3), which
        # would give a kernel of -1 and make ``nn.Conv2d`` fail; 1 is the
        # smallest valid odd kernel.
        return max(kernel, 1)

    def forward(self, x):
        """Apply the attention branches and average them.

        Args:
            x: 4-D tensor, presumably (batch, channels, height, width).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Branch 1: swap dims 1 and 2 so the gate attends along dim 2 of x,
        # then swap back.
        x_h = x.permute(0, 2, 1, 3).contiguous()
        x_h = self.h_cw(x_h)
        x_h = x_h.permute(0, 2, 1, 3).contiguous()

        # Branch 2: swap dims 1 and 3 so the gate attends along dim 3 of x,
        # then swap back.
        x_w = x.permute(0, 3, 2, 1).contiguous()
        x_w = self.w_hc(x_w)
        x_w = x_w.permute(0, 3, 2, 1).contiguous()

        if not self.no_spatial:
            # Branch 3: gate along dim 1 of the un-permuted input.
            x_c = self.c_hw(x)
            x_out = 1 / 3 * (x_c + x_h + x_w)
        else:
            x_out = 1 / 2 * (x_h + x_w)

        return x_out
if __name__ == '__main__':
    # Smoke test: push one random batch through the layer.
    channels = 32
    sample = torch.randn(1, channels, 256, 256)
    layer = MCALayer(channels)
    result = layer(sample)
    # Report the input and output shapes.
    print("Input size:", sample.size())
    print("Output size:", result.size())