forked from AIFengheshu/Plug-play-modules
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path(CVPR 2024) StarNet.py
More file actions
166 lines (140 loc) · 6.37 KB
/
(CVPR 2024) StarNet.py
File metadata and controls
166 lines (140 loc) · 6.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from torchsummary import summary
# 论文题目:Rewrite the Stars
# 中文题目: 重写星操作
# 论文链接:https://arxiv.org/pdf/2403.19967
# 官方github:https://github.com/ma-xu/Rewrite-the-Stars
# 所属机构:东北大学,微软
# 关键词:星操作、网络设计、StarNet、高效网络、核技巧
# 微信公众号:AI缝合术
# Download locations of the released pretrained checkpoints, keyed by the
# variant name used by the factory functions in this file. Every checkpoint
# lives under the same release root and is named "<variant>.pth.tar".
_CHECKPOINT_ROOT = "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1"
model_urls = {
    variant: f"{_CHECKPOINT_ROOT}/{variant}.pth.tar"
    for variant in ("starnet_s1", "starnet_s2", "starnet_s3", "starnet_s4")
}
class ConvBN(torch.nn.Sequential):
    """A ``Conv2d`` optionally followed by a ``BatchNorm2d``.

    The convolution is registered as the child ``conv``; when ``with_bn`` is
    true a batch-norm layer is registered after it as the child ``bn``.

    Args:
        in_planes: Number of input channels of the convolution.
        out_planes: Number of output channels (and batch-norm features).
        kernel_size: Forwarded to ``torch.nn.Conv2d``.
        stride: Forwarded to ``torch.nn.Conv2d``.
        padding: Forwarded to ``torch.nn.Conv2d``.
        dilation: Forwarded to ``torch.nn.Conv2d``.
        groups: Forwarded to ``torch.nn.Conv2d``.
        with_bn: When true, append a ``BatchNorm2d(out_planes)`` whose weight
            is set to 1 and whose bias is set to 0.
    """

    def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, with_bn=True):
        super().__init__()
        conv = torch.nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        self.add_module('conv', conv)
        if not with_bn:
            return
        norm = torch.nn.BatchNorm2d(out_planes)
        # Start the affine transform as the identity (scale 1, shift 0).
        torch.nn.init.constant_(norm.weight, 1)
        torch.nn.init.constant_(norm.bias, 0)
        self.add_module('bn', norm)
class Block(nn.Module):
    """Residual block that multiplies two parallel 1x1 projections.

    Data flow in ``forward``: a 7x7 depthwise conv with batch norm, two
    parallel 1x1 convs expanding to ``mlp_ratio * dim`` channels, ReLU6 on
    the first branch multiplied element-wise with the second, a 1x1 conv
    with batch norm back to ``dim`` channels, a second 7x7 depthwise conv
    without batch norm, and finally a residual addition with the (optionally
    drop-path'ed) result.

    Args:
        dim: Channel count of the input and output.
        mlp_ratio: Expansion factor for the two 1x1 branches.
        drop_path: Drop-path rate; ``DropPath`` is used only when it is
            greater than zero, otherwise an identity module is used.
    """

    def __init__(self, dim, mlp_ratio=3, drop_path=0.):
        super().__init__()
        hidden_dim = mlp_ratio * dim
        same_pad = (7 - 1) // 2  # keeps spatial size for a 7x7 kernel at stride 1
        self.dwconv = ConvBN(dim, dim, 7, 1, same_pad, groups=dim, with_bn=True)
        self.f1 = ConvBN(dim, hidden_dim, 1, with_bn=False)
        self.f2 = ConvBN(dim, hidden_dim, 1, with_bn=False)
        self.g = ConvBN(hidden_dim, dim, 1, with_bn=True)
        self.dwconv2 = ConvBN(dim, dim, 7, 1, same_pad, groups=dim, with_bn=False)
        self.act = nn.ReLU6()
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        shortcut = x
        feat = self.dwconv(x)
        # Element-wise product of the activated branch and the linear branch.
        mixed = self.act(self.f1(feat)) * self.f2(feat)
        residual = self.dwconv2(self.g(mixed))
        return shortcut + self.drop_path(residual)
class StarNet(nn.Module):
    """Four-stage convolutional classifier built from ``Block`` modules.

    Layout:
        * stem: stride-2 3x3 ``ConvBN`` from 3 to 32 channels, then ReLU6;
        * one stage per entry of ``depths``: a stride-2 3x3 ``ConvBN`` that
          changes the width to ``base_dim * 2 ** stage_index`` followed by
          that many ``Block`` modules;
        * head: ``BatchNorm2d`` -> global average pool -> ``Linear``.

    Args:
        base_dim: Channel width of the first stage; doubled at every stage.
        depths: Number of ``Block`` modules in each stage.
        mlp_ratio: Expansion factor passed to every ``Block``.
        drop_path_rate: Maximum drop-path rate; per-block rates grow linearly
            from 0 to this value over all blocks.
        num_classes: Output size of the classification head.
        **kwargs: Accepted and ignored, so extra keyword arguments from
            callers do not raise.
    """

    def __init__(self, base_dim=32, depths=(3, 3, 12, 5), mlp_ratio=4, drop_path_rate=0.0, num_classes=1000, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.in_channel = 32
        # stem layer
        self.stem = nn.Sequential(ConvBN(3, self.in_channel, kernel_size=3, stride=2, padding=1), nn.ReLU6())
        # stochastic depth: one rate per block, linearly increasing
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        # build stages
        self.stages = nn.ModuleList()
        cur = 0  # index into dpr of the first block of the current stage
        for i_layer, depth in enumerate(depths):
            embed_dim = base_dim * 2 ** i_layer
            down_sampler = ConvBN(self.in_channel, embed_dim, 3, 2, 1)
            self.in_channel = embed_dim
            blocks = [Block(self.in_channel, mlp_ratio, dpr[cur + i]) for i in range(depth)]
            cur += depth
            self.stages.append(nn.Sequential(down_sampler, *blocks))
        # head
        self.norm = nn.BatchNorm2d(self.in_channel)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(self.in_channel, num_classes)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise a single sub-module; used as the ``self.apply`` callback.

        Linear and Conv2d weights get a truncated normal (std 0.02); Linear
        biases are zeroed. LayerNorm and BatchNorm2d are reset to weight 1,
        bias 0. Any other module type is left untouched.

        The type checks use tuples: the previous ``isinstance(m, A or B)``
        form evaluated ``A or B`` to ``A`` and therefore never matched
        ``nn.Conv2d`` or ``nn.BatchNorm2d``.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            # Norm layers built without affine parameters have None here.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` must have 3 channels, since the stem convolution takes 3 input
        channels.
        """
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        x = torch.flatten(self.avgpool(self.norm(x)), 1)
        return self.head(x)
@register_model
def starnet_s1(pretrained=False, **kwargs):
    """Build the StarNet-S1 variant (base width 24, stage depths 2/2/8/3).

    Args:
        pretrained: When true, download the ``starnet_s1`` checkpoint listed
            in ``model_urls`` (mapped to CPU) and load its ``"state_dict"``
            entry into the model.
        **kwargs: Forwarded to the ``StarNet`` constructor.

    Returns:
        The constructed ``StarNet`` instance.
    """
    net = StarNet(24, [2, 2, 8, 3], **kwargs)
    if not pretrained:
        return net
    ckpt = torch.hub.load_state_dict_from_url(url=model_urls['starnet_s1'], map_location="cpu")
    net.load_state_dict(ckpt["state_dict"])
    return net
@register_model
def starnet_s2(pretrained=False, **kwargs):
    """Build the StarNet-S2 variant (base width 32, stage depths 1/2/6/2).

    Args:
        pretrained: When true, download the ``starnet_s2`` checkpoint listed
            in ``model_urls`` (mapped to CPU) and load its ``"state_dict"``
            entry into the model.
        **kwargs: Forwarded to the ``StarNet`` constructor.

    Returns:
        The constructed ``StarNet`` instance.
    """
    net = StarNet(32, [1, 2, 6, 2], **kwargs)
    if not pretrained:
        return net
    ckpt = torch.hub.load_state_dict_from_url(url=model_urls['starnet_s2'], map_location="cpu")
    net.load_state_dict(ckpt["state_dict"])
    return net
@register_model
def starnet_s3(pretrained=False, **kwargs):
    """Build the StarNet-S3 variant (base width 32, stage depths 2/2/8/4).

    Args:
        pretrained: When true, download the ``starnet_s3`` checkpoint listed
            in ``model_urls`` (mapped to CPU) and load its ``"state_dict"``
            entry into the model.
        **kwargs: Forwarded to the ``StarNet`` constructor.

    Returns:
        The constructed ``StarNet`` instance.
    """
    net = StarNet(32, [2, 2, 8, 4], **kwargs)
    if not pretrained:
        return net
    ckpt = torch.hub.load_state_dict_from_url(url=model_urls['starnet_s3'], map_location="cpu")
    net.load_state_dict(ckpt["state_dict"])
    return net
@register_model
def starnet_s4(pretrained=False, **kwargs):
    """Build the StarNet-S4 variant (base width 32, stage depths 3/3/12/5).

    Args:
        pretrained: When true, download the ``starnet_s4`` checkpoint listed
            in ``model_urls`` (mapped to CPU) and load its ``"state_dict"``
            entry into the model.
        **kwargs: Forwarded to the ``StarNet`` constructor.

    Returns:
        The constructed ``StarNet`` instance.
    """
    net = StarNet(32, [3, 3, 12, 5], **kwargs)
    if not pretrained:
        return net
    ckpt = torch.hub.load_state_dict_from_url(url=model_urls['starnet_s4'], map_location="cpu")
    net.load_state_dict(ckpt["state_dict"])
    return net
# very small networks #
@register_model
def starnet_s050(pretrained=False, **kwargs):
    """Build the small StarNet-S050 variant (base width 16, depths 1/1/3/1, mlp_ratio 3).

    Args:
        pretrained: Accepted for signature parity with the other factories.
            This file lists no checkpoint URL for this variant, so a true
            value cannot be honoured; a warning is emitted and the model is
            returned with its freshly initialised weights.
        **kwargs: Forwarded to the ``StarNet`` constructor.

    Returns:
        The constructed ``StarNet`` instance.
    """
    if pretrained:
        # Local import keeps this function self-contained.
        import warnings

        warnings.warn(
            "starnet_s050 has no pretrained checkpoint; "
            "returning a randomly initialized model.",
            stacklevel=2,
        )
    return StarNet(16, [1, 1, 3, 1], 3, **kwargs)
@register_model
def starnet_s100(pretrained=False, **kwargs):
    """Build the small StarNet-S100 variant (base width 20, depths 1/2/4/1, mlp_ratio 4).

    Args:
        pretrained: Accepted for signature parity with the other factories.
            This file lists no checkpoint URL for this variant, so a true
            value cannot be honoured; a warning is emitted and the model is
            returned with its freshly initialised weights.
        **kwargs: Forwarded to the ``StarNet`` constructor.

    Returns:
        The constructed ``StarNet`` instance.
    """
    if pretrained:
        # Local import keeps this function self-contained.
        import warnings

        warnings.warn(
            "starnet_s100 has no pretrained checkpoint; "
            "returning a randomly initialized model.",
            stacklevel=2,
        )
    return StarNet(20, [1, 2, 4, 1], 4, **kwargs)
@register_model
def starnet_s150(pretrained=False, **kwargs):
    """Build the small StarNet-S150 variant (base width 24, depths 1/2/4/2, mlp_ratio 3).

    Args:
        pretrained: Accepted for signature parity with the other factories.
            This file lists no checkpoint URL for this variant, so a true
            value cannot be honoured; a warning is emitted and the model is
            returned with its freshly initialised weights.
        **kwargs: Forwarded to the ``StarNet`` constructor.

    Returns:
        The constructed ``StarNet`` instance.
    """
    if pretrained:
        # Local import keeps this function self-contained.
        import warnings

        warnings.warn(
            "starnet_s150 has no pretrained checkpoint; "
            "returning a randomly initialized model.",
            stacklevel=2,
        )
    return StarNet(24, [1, 2, 4, 2], 3, **kwargs)
if __name__ == "__main__":
    # Smoke test: build starnet_s1, run one forward pass on random data and
    # print a layer summary.

    # Pick the compute device (GPU when CUDA is available, otherwise CPU).
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    # Build the starnet_s1 model without pretrained weights.
    net = starnet_s1(pretrained=False, num_classes=1000)
    net = net.to(device)
    print("Model loaded successfully.")

    # Dummy input: batch of 1, 3-channel image of size 224x224.
    dummy_batch = torch.randn(1, 3, 224, 224).to(device)

    # Inference in evaluation mode with gradient tracking disabled.
    net.eval()
    with torch.no_grad():
        output = net(dummy_batch)

    # Report the shape of the output.
    print(f"Output shape: {output.shape}")

    # Print the model summary (optional).
    summary(net, input_size=(3, 224, 224))