Skip to content

Commit fa2836c

Browse files
colesburysoumith
authored andcommitted
Add pre-trained VGG models with batch normalization (#178)
Fixes #152
1 parent 83263d8 commit fa2836c

File tree

2 files changed

+69
-33
lines changed

2 files changed

+69
-33
lines changed

torchvision/models/__init__.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,27 +29,31 @@
2929
3030
ImageNet 1-crop error rates (224x224)
3131
32-
======================== ============= =============
33-
Network Top-1 error Top-5 error
34-
======================== ============= =============
35-
ResNet-18 30.24 10.92
36-
ResNet-34 26.70 8.58
37-
ResNet-50 23.85 7.13
38-
ResNet-101 22.63 6.44
39-
ResNet-152 21.69 5.94
40-
Inception v3 22.55 6.44
41-
AlexNet 43.45 20.91
42-
VGG-11 30.98 11.37
43-
VGG-13 30.07 10.75
44-
VGG-16 28.41 9.62
45-
VGG-19 27.62 9.12
46-
SqueezeNet 1.0 41.90 19.58
47-
SqueezeNet 1.1 41.81 19.38
48-
Densenet-121 25.35 7.83
49-
Densenet-169 24.00 7.00
50-
Densenet-201 22.80 6.43
51-
Densenet-161 22.35 6.20
52-
======================== ============= =============
32+
================================ ============= =============
33+
Network Top-1 error Top-5 error
34+
================================ ============= =============
35+
ResNet-18 30.24 10.92
36+
ResNet-34 26.70 8.58
37+
ResNet-50 23.85 7.13
38+
ResNet-101 22.63 6.44
39+
ResNet-152 21.69 5.94
40+
Inception v3 22.55 6.44
41+
AlexNet 43.45 20.91
42+
VGG-11 30.98 11.37
43+
VGG-13 30.07 10.75
44+
VGG-16 28.41 9.62
45+
VGG-19 27.62 9.12
46+
VGG-11 with batch normalization 29.62 10.19
47+
VGG-13 with batch normalization 28.45 9.63
48+
VGG-16 with batch normalization 26.63 8.50
49+
VGG-19 with batch normalization 25.76 8.15
50+
SqueezeNet 1.0 41.90 19.58
51+
SqueezeNet 1.1 41.81 19.38
52+
Densenet-121 25.35 7.83
53+
Densenet-169 24.00 7.00
54+
Densenet-201 22.80 6.43
55+
Densenet-161 22.35 6.20
56+
================================ ============= =============
5357
5458
5559
.. _AlexNet: https://arxiv.org/abs/1404.5997

torchvision/models/vgg.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
1515
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
1616
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
17+
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
18+
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
19+
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
20+
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
1721
}
1822

1923

@@ -91,9 +95,16 @@ def vgg11(pretrained=False, **kwargs):
9195
return model
9296

9397

94-
def vgg11_bn(**kwargs):
95-
"""VGG 11-layer model (configuration "A") with batch normalization"""
96-
return VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
98+
def vgg11_bn(pretrained=False, **kwargs):
99+
"""VGG 11-layer model (configuration "A") with batch normalization
100+
101+
Args:
102+
pretrained (bool): If True, returns a model pre-trained on ImageNet
103+
"""
104+
model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
105+
if pretrained:
106+
model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
107+
return model
97108

98109

99110
def vgg13(pretrained=False, **kwargs):
@@ -108,9 +119,16 @@ def vgg13(pretrained=False, **kwargs):
108119
return model
109120

110121

111-
def vgg13_bn(**kwargs):
112-
"""VGG 13-layer model (configuration "B") with batch normalization"""
113-
return VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
122+
def vgg13_bn(pretrained=False, **kwargs):
123+
"""VGG 13-layer model (configuration "B") with batch normalization
124+
125+
Args:
126+
pretrained (bool): If True, returns a model pre-trained on ImageNet
127+
"""
128+
model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
129+
if pretrained:
130+
model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
131+
return model
114132

115133

116134
def vgg16(pretrained=False, **kwargs):
@@ -125,9 +143,16 @@ def vgg16(pretrained=False, **kwargs):
125143
return model
126144

127145

128-
def vgg16_bn(**kwargs):
129-
"""VGG 16-layer model (configuration "D") with batch normalization"""
130-
return VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
146+
def vgg16_bn(pretrained=False, **kwargs):
147+
"""VGG 16-layer model (configuration "D") with batch normalization
148+
149+
Args:
150+
pretrained (bool): If True, returns a model pre-trained on ImageNet
151+
"""
152+
model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
153+
if pretrained:
154+
model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
155+
return model
131156

132157

133158
def vgg19(pretrained=False, **kwargs):
@@ -142,6 +167,13 @@ def vgg19(pretrained=False, **kwargs):
142167
return model
143168

144169

145-
def vgg19_bn(**kwargs):
146-
"""VGG 19-layer model (configuration 'E') with batch normalization"""
147-
return VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
170+
def vgg19_bn(pretrained=False, **kwargs):
171+
"""VGG 19-layer model (configuration 'E') with batch normalization
172+
173+
Args:
174+
pretrained (bool): If True, returns a model pre-trained on ImageNet
175+
"""
176+
model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
177+
if pretrained:
178+
model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
179+
return model

0 commit comments

Comments
 (0)