forked from AIFengheshu/Plug-play-modules
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path(CVPR 2023) FSAS.py
More file actions
114 lines (90 loc) · 4.45 KB
/
(CVPR 2023) FSAS.py
File metadata and controls
114 lines (90 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
import torch.nn as nn
import numbers
from einops import rearrange
import time
# 论文题目:Efficient Frequency Domain-based Transformers for High-Quality Image Deblurring
# 中文题目:基于频域的高效Transformer用于高质量图像去模糊
# 论文链接:https://openaccess.thecvf.com/content/CVPR2023/papers/Kong_Efficient_Frequency_Domain-Based_Transformers_for_High-Quality_Image_Deblurring_CVPR_2023_paper.pdf
# 官方github:https://github.com/kkkls/FFTformer
# 所属机构:南京理工大学计算机科学与工程学院,中国电子科技集团信息科学研究院
# 代码整理:微信公众号《AI缝合术》
def to_3d(x):
    """Flatten a 4-D feature map (b, c, h, w) into token form (b, h*w, c)."""
    b, c, h, w = x.shape
    return x.reshape(b, c, h * w).permute(0, 2, 1)
def to_4d(x, h, w):
    """Inverse of to_3d: reshape tokens (b, h*w, c) back to a map (b, c, h, w)."""
    b, _, c = x.shape
    return x.permute(0, 2, 1).reshape(b, c, h, w)
class BiasFree_LayerNorm(nn.Module):
    """Scale-only layer norm over the last dim: no mean subtraction, no bias.

    Divides by the biased std of the last dimension and applies a learned
    per-channel scale (initialized to ones).
    """

    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        shape = normalized_shape
        if isinstance(shape, numbers.Integral):
            shape = (shape,)
        shape = torch.Size(shape)
        # Only a single trailing dimension is supported.
        assert len(shape) == 1
        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        # Biased (population) variance over the last dim; eps keeps sqrt stable.
        variance = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(variance + 1e-5) * self.weight
class WithBias_LayerNorm(nn.Module):
    """Standard layer norm over the last dim with learned scale and bias.

    Centers by the mean, divides by the biased std, then applies a learned
    per-channel affine transform (weight init ones, bias init zeros).
    """

    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        shape = normalized_shape
        if isinstance(shape, numbers.Integral):
            shape = (shape,)
        shape = torch.Size(shape)
        # Only a single trailing dimension is supported.
        assert len(shape) == 1
        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        # Biased (population) variance; eps keeps sqrt stable.
        variance = x.var(-1, keepdim=True, unbiased=False)
        normalized = (x - mean) / torch.sqrt(variance + 1e-5)
        return normalized * self.weight + self.bias
class LayerNorm(nn.Module):
    """Channel-wise layer norm for 4-D maps via a (b,c,h,w) <-> (b,hw,c) round trip.

    `LayerNorm_type == 'BiasFree'` selects the scale-only variant; anything
    else selects the affine (with-bias) variant.
    """

    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        self.body = (BiasFree_LayerNorm(dim)
                     if LayerNorm_type == 'BiasFree'
                     else WithBias_LayerNorm(dim))

    def forward(self, x):
        height, width = x.shape[-2], x.shape[-1]
        # Normalize in token form, then restore the spatial layout.
        return to_4d(self.body(to_3d(x)), height, width)
class FSAS(nn.Module):
    """Frequency-domain self-attention from FFTformer (CVPR 2023).

    Per non-overlapping 8x8 patch, multiplies the FFTs of q and k (an
    element-wise product in the frequency domain), transforms back, normalizes
    the result, and uses it to gate v.  Input and output are (b, dim, h, w);
    h and w are assumed to be multiples of `patch_size`.
    """

    def __init__(self, dim, bias=False):
        super(FSAS, self).__init__()
        # 1x1 conv expands to 6*dim channels: q, k, v, each 2*dim wide.
        self.to_hidden = nn.Conv2d(dim, dim * 6, kernel_size=1, bias=bias)
        # Depthwise 3x3 mixes local context before the q/k/v split.
        self.to_hidden_dw = nn.Conv2d(dim * 6, dim * 6, kernel_size=3, stride=1, padding=1, groups=dim * 6, bias=bias)
        self.project_out = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)
        self.norm = LayerNorm(dim * 2, LayerNorm_type='WithBias')
        self.patch_size = 8

    def forward(self, x):
        p = self.patch_size
        q, k, v = self.to_hidden_dw(self.to_hidden(x)).chunk(3, dim=1)

        # Tile q and k into non-overlapping p x p patches.
        q_patch = rearrange(q, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=p,
                            patch2=p)
        k_patch = rearrange(k, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=p,
                            patch2=p)

        # Product in the frequency domain; cast to float for the FFT.
        spectrum = torch.fft.rfft2(q_patch.float()) * torch.fft.rfft2(k_patch.float())
        attn = torch.fft.irfft2(spectrum, s=(p, p))

        # Stitch patches back into a full-resolution map and normalize it.
        attn = rearrange(attn, 'b c h w patch1 patch2 -> b c (h patch1) (w patch2)', patch1=p,
                         patch2=p)
        attn = self.norm(attn)

        # Gate v by the normalized attention map, then project back to `dim`.
        return self.project_out(v * attn)
if __name__ == '__main__':
    # --- optional diagnostics (safe to comment out) ---
    print("当前系统时间:", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"PyTorch 版本: {torch.__version__}")
    print(f"CUDA 版本: {torch.version.cuda}")
    print(f"CUDA 是否可用: {torch.cuda.is_available()}")
    print("微信公众号:AI缝合术,The test was successful!")
    # --- end optional diagnostics ---

    # Smoke test: one forward pass through FSAS on a random feature map.
    model = FSAS(32).to(device)
    x = torch.rand(1, 32, 256, 256).to(device)  # input tensor
    y = model(x)  # forward pass
    print(f"\n输入张量形状: {x.shape}")
    print(f"输出张量形状: {y.shape}")