import torch
import torch.nn as nn
from einops import rearrange
from torch import einsum
# Paper title: Rethinking Transformer-Based Blind-Spot Network for Self-Supervised Image Denoising
# Paper link: https://arxiv.org/pdf/2404.07846
# Official GitHub: https://github.com/nagejacob/TBSN
# Affiliation: Harbin Institute of Technology
# Code curated by: WeChat public account "AI缝合术"
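
# MWSAttention is the overlapping-window attention block from the TBSN paper above:
# queries come from non-overlapping windows of the feature map, while keys/values
# come from larger, overlapping windows extracted with nn.Unfold. A learnable
# relative position embedding (RelPosEmb) and a fixed attention mask (FixedPosEmb)
# are added to the attention logits before the softmax.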
def to(x):
    return {'device': x.device, 'dtype': x.dtype}


def expand_dim(t, dim, k):
    t = t.unsqueeze(dim=dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)
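
# Converts relative-position logits (indexed by relative offset) into absolute
# indexing via the standard pad-and-reshape trick, so each query row is left with
# the logits for its own range of key positions.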
def rel_to_abs(x):
    b, l, m = x.shape
    r = (m + 1) // 2

    col_pad = torch.zeros((b, l, 1), **to(x))
    x = torch.cat((x, col_pad), dim=2)
    flat_x = rearrange(x, 'b l c -> b (l c)')
    flat_pad = torch.zeros((b, m - l), **to(x))
    flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)
    final_x = flat_x_padded.reshape(b, l + 1, m)
    final_x = final_x[:, :l, -r:]
    return final_x
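
# Computes 1D relative-position logits between the query grid and a learned table
# of relative key embeddings, then expands the result across the other spatial axis.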
def relative_logits_1d(q, rel_k):
    b, h, w, _ = q.shape
    r = (rel_k.shape[0] + 1) // 2

    logits = einsum('b x y d, r d -> b x y r', q, rel_k)
    logits = rearrange(logits, 'b x y r -> (b x) y r')
    logits = rel_to_abs(logits)

    logits = logits.reshape(b, h, w, r)
    logits = expand_dim(logits, dim=2, k=r)
    return logits
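
# Learnable 2D relative position embedding: separate height and width tables are
# contracted with the queries and the two 1D results are summed (a decomposed
# relative-attention formulation).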
class RelPosEmb(nn.Module):
    def __init__(
        self,
        block_size,
        rel_size,
        dim_head
    ):
        super().__init__()
        height = width = rel_size
        scale = dim_head ** -0.5

        self.block_size = block_size
        self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
        self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)

    def forward(self, q):
        block = self.block_size

        q = rearrange(q, 'b (x y) c -> b x y c', x=block)
        rel_logits_w = relative_logits_1d(q, self.rel_width)
        rel_logits_w = rearrange(rel_logits_w, 'b x i y j -> b (x y) (i j)')

        q = rearrange(q, 'b x y d -> b y x d')
        rel_logits_h = relative_logits_1d(q, self.rel_height)
        rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')

        return rel_logits_w + rel_logits_h
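
# Fixed (non-learnable) attention mask: every other entry of the relative-offset
# table, along both axes, is set to -inf, so each query attends only to a strided
# (dilated) subset of key positions in the overlapping window. The mask is
# precomputed from the pairwise relative coordinates between the query window and
# the larger key/value window.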
class FixedPosEmb(nn.Module):
    def __init__(self, window_size, overlap_window_size):
        super().__init__()
        self.window_size = window_size
        self.overlap_window_size = overlap_window_size

        attention_mask_table = torch.zeros((window_size + overlap_window_size - 1), (window_size + overlap_window_size - 1))
        attention_mask_table[0::2, :] = float('-inf')
        attention_mask_table[:, 0::2] = float('-inf')
        attention_mask_table = attention_mask_table.view((window_size + overlap_window_size - 1) * (window_size + overlap_window_size - 1))

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size)
        coords_w = torch.arange(self.window_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten_1 = torch.flatten(coords, 1)  # 2, Wh*Ww
        coords_h = torch.arange(self.overlap_window_size)
        coords_w = torch.arange(self.overlap_window_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten_2 = torch.flatten(coords, 1)

        relative_coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.overlap_window_size - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.overlap_window_size - 1
        relative_coords[:, :, 0] *= self.window_size + self.overlap_window_size - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

        self.attention_mask = nn.Parameter(attention_mask_table[relative_position_index.view(-1)].view(
            1, self.window_size ** 2, self.overlap_window_size ** 2
        ), requires_grad=False)

    def forward(self):
        return self.attention_mask
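
# Overlapping-window multi-head attention:
#   - queries: non-overlapping window_size x window_size windows of the input;
#   - keys/values: larger overlap_win_size x overlap_win_size windows extracted
#     with nn.Unfold (stride = window_size), so each query window sees context
#     beyond its own borders;
#   - relative position logits (RelPosEmb) and the fixed mask (FixedPosEmb) are
#     added to the scaled dot-product scores before the softmax.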
class MWSAttention(nn.Module):
    def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):
        super(MWSAttention, self).__init__()
        self.num_spatial_heads = num_heads
        self.dim = dim
        self.window_size = window_size
        self.overlap_win_size = int(window_size * overlap_ratio) + window_size
        self.dim_head = dim_head
        self.inner_dim = self.dim_head * self.num_spatial_heads
        self.scale = self.dim_head ** -0.5

        self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size, padding=(self.overlap_win_size - window_size) // 2)
        self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)
        self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)
        self.rel_pos_emb = RelPosEmb(
            block_size=window_size,
            rel_size=window_size + (self.overlap_win_size - window_size),
            dim_head=self.dim_head
        )
        self.fixed_pos_emb = FixedPosEmb(window_size, self.overlap_win_size)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.qkv(x)
        qs, ks, vs = qkv.chunk(3, dim=1)

        # spatial attention: queries from non-overlapping windows, keys/values from overlapping windows
        qs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)
        ks, vs = map(lambda t: self.unfold(t), (ks, vs))
        ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))

        # split heads
        qs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads), (qs, ks, vs))

        # scaled dot-product attention with relative position logits and the fixed mask
        qs = qs * self.scale
        spatial_attn = (qs @ ks.transpose(-2, -1))
        spatial_attn += self.rel_pos_emb(qs)
        spatial_attn += self.fixed_pos_emb()
        spatial_attn = spatial_attn.softmax(dim=-1)
        out = (spatial_attn @ vs)

        # merge windows and heads back into a feature map, then project to the output dim
        out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads, h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)
        out = self.project_out(out)
        return out
if __name__ == "__main__":
    # Hyperparameters
    dim = 32
    window_size = 8
    overlap_ratio = 0.5
    num_heads = 2
    dim_head = 16

    # Instantiate the MWSAttention module
    mws_attention = MWSAttention(
        dim=dim,
        window_size=window_size,
        overlap_ratio=overlap_ratio,
        num_heads=num_heads,
        dim_head=dim_head,
        bias=True
    )

    # Move the module to the GPU if one is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mws_attention = mws_attention.to(device)
    print(mws_attention)

    # Create a test input tensor: a 1 x 32 x 256 x 256 feature map
    x = torch.randn(1, 32, 256, 256).to(device)

    # Forward pass
    output = mws_attention(x)

    # Print the input and output tensor shapes
    print("Input shape:", x.shape)
    print("Output shape:", output.shape)