Skip to content

Commit 6f7e26b

Browse files
ssnlfmassa
authored and committed
Fix Densenet module keys (#474)
1 parent f6ab107 commit 6f7e26b

File tree

1 file changed

+63
-10
lines changed

1 file changed

+63
-10
lines changed

torchvision/models/densenet.py

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
import torch
23
import torch.nn as nn
34
import torch.nn.functional as F
@@ -25,7 +26,20 @@ def densenet121(pretrained=False, **kwargs):
2526
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
2627
**kwargs)
2728
if pretrained:
28-
model.load_state_dict(model_zoo.load_url(model_urls['densenet121']))
29+
# '.'s are no longer allowed in module names, but previous _DenseLayer
30+
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
31+
# They are also in the checkpoints in model_urls. This pattern is used
32+
# to find such keys.
33+
pattern = re.compile(
34+
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
35+
state_dict = model_zoo.load_url(model_urls['densenet121'])
36+
for key in list(state_dict.keys()):
37+
res = pattern.match(key)
38+
if res:
39+
new_key = res.group(1) + res.group(2)
40+
state_dict[new_key] = state_dict[key]
41+
del state_dict[key]
42+
model.load_state_dict(state_dict)
2943
return model
3044

3145

@@ -39,7 +53,20 @@ def densenet169(pretrained=False, **kwargs):
3953
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
4054
**kwargs)
4155
if pretrained:
42-
model.load_state_dict(model_zoo.load_url(model_urls['densenet169']))
56+
# '.'s are no longer allowed in module names, but previous _DenseLayer
57+
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
58+
# They are also in the checkpoints in model_urls. This pattern is used
59+
# to find such keys.
60+
pattern = re.compile(
61+
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
62+
state_dict = model_zoo.load_url(model_urls['densenet169'])
63+
for key in list(state_dict.keys()):
64+
res = pattern.match(key)
65+
if res:
66+
new_key = res.group(1) + res.group(2)
67+
state_dict[new_key] = state_dict[key]
68+
del state_dict[key]
69+
model.load_state_dict(state_dict)
4370
return model
4471

4572

@@ -53,7 +80,20 @@ def densenet201(pretrained=False, **kwargs):
5380
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
5481
**kwargs)
5582
if pretrained:
56-
model.load_state_dict(model_zoo.load_url(model_urls['densenet201']))
83+
# '.'s are no longer allowed in module names, but previous _DenseLayer
84+
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
85+
# They are also in the checkpoints in model_urls. This pattern is used
86+
# to find such keys.
87+
pattern = re.compile(
88+
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
89+
state_dict = model_zoo.load_url(model_urls['densenet201'])
90+
for key in list(state_dict.keys()):
91+
res = pattern.match(key)
92+
if res:
93+
new_key = res.group(1) + res.group(2)
94+
state_dict[new_key] = state_dict[key]
95+
del state_dict[key]
96+
model.load_state_dict(state_dict)
5797
return model
5898

5999

@@ -67,20 +107,33 @@ def densenet161(pretrained=False, **kwargs):
67107
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
68108
**kwargs)
69109
if pretrained:
70-
model.load_state_dict(model_zoo.load_url(model_urls['densenet161']))
110+
# '.'s are no longer allowed in module names, but previous _DenseLayer
111+
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
112+
# They are also in the checkpoints in model_urls. This pattern is used
113+
# to find such keys.
114+
pattern = re.compile(
115+
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
116+
state_dict = model_zoo.load_url(model_urls['densenet161'])
117+
for key in list(state_dict.keys()):
118+
res = pattern.match(key)
119+
if res:
120+
new_key = res.group(1) + res.group(2)
121+
state_dict[new_key] = state_dict[key]
122+
del state_dict[key]
123+
model.load_state_dict(state_dict)
71124
return model
72125

73126

74127
class _DenseLayer(nn.Sequential):
75128
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
76129
super(_DenseLayer, self).__init__()
77-
self.add_module('norm.1', nn.BatchNorm2d(num_input_features)),
78-
self.add_module('relu.1', nn.ReLU(inplace=True)),
79-
self.add_module('conv.1', nn.Conv2d(num_input_features, bn_size *
130+
self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
131+
self.add_module('relu1', nn.ReLU(inplace=True)),
132+
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
80133
growth_rate, kernel_size=1, stride=1, bias=False)),
81-
self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate)),
82-
self.add_module('relu.2', nn.ReLU(inplace=True)),
83-
self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate,
134+
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
135+
self.add_module('relu2', nn.ReLU(inplace=True)),
136+
self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
84137
kernel_size=3, stride=1, padding=1, bias=False)),
85138
self.drop_rate = drop_rate
86139

0 commit comments

Comments
 (0)