1
+ import re
1
2
import torch
2
3
import torch .nn as nn
3
4
import torch .nn .functional as F
@@ -25,7 +26,20 @@ def densenet121(pretrained=False, **kwargs):
25
26
model = DenseNet (num_init_features = 64 , growth_rate = 32 , block_config = (6 , 12 , 24 , 16 ),
26
27
** kwargs )
27
28
if pretrained :
28
- model .load_state_dict (model_zoo .load_url (model_urls ['densenet121' ]))
29
+ # '.'s are no longer allowed in module names, but pervious _DenseLayer
30
+ # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
31
+ # They are also in the checkpoints in model_urls. This pattern is used
32
+ # to find such keys.
33
+ pattern = re .compile (
34
+ r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$' )
35
+ state_dict = model_zoo .load_url (model_urls ['densenet121' ])
36
+ for key in list (state_dict .keys ()):
37
+ res = pattern .match (key )
38
+ if res :
39
+ new_key = res .group (1 ) + res .group (2 )
40
+ state_dict [new_key ] = state_dict [key ]
41
+ del state_dict [key ]
42
+ model .load_state_dict (state_dict )
29
43
return model
30
44
31
45
@@ -39,7 +53,20 @@ def densenet169(pretrained=False, **kwargs):
39
53
model = DenseNet (num_init_features = 64 , growth_rate = 32 , block_config = (6 , 12 , 32 , 32 ),
40
54
** kwargs )
41
55
if pretrained :
42
- model .load_state_dict (model_zoo .load_url (model_urls ['densenet169' ]))
56
+ # '.'s are no longer allowed in module names, but pervious _DenseLayer
57
+ # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
58
+ # They are also in the checkpoints in model_urls. This pattern is used
59
+ # to find such keys.
60
+ pattern = re .compile (
61
+ r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$' )
62
+ state_dict = model_zoo .load_url (model_urls ['densenet169' ])
63
+ for key in list (state_dict .keys ()):
64
+ res = pattern .match (key )
65
+ if res :
66
+ new_key = res .group (1 ) + res .group (2 )
67
+ state_dict [new_key ] = state_dict [key ]
68
+ del state_dict [key ]
69
+ model .load_state_dict (state_dict )
43
70
return model
44
71
45
72
@@ -53,7 +80,20 @@ def densenet201(pretrained=False, **kwargs):
53
80
model = DenseNet (num_init_features = 64 , growth_rate = 32 , block_config = (6 , 12 , 48 , 32 ),
54
81
** kwargs )
55
82
if pretrained :
56
- model .load_state_dict (model_zoo .load_url (model_urls ['densenet201' ]))
83
+ # '.'s are no longer allowed in module names, but pervious _DenseLayer
84
+ # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
85
+ # They are also in the checkpoints in model_urls. This pattern is used
86
+ # to find such keys.
87
+ pattern = re .compile (
88
+ r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$' )
89
+ state_dict = model_zoo .load_url (model_urls ['densenet201' ])
90
+ for key in list (state_dict .keys ()):
91
+ res = pattern .match (key )
92
+ if res :
93
+ new_key = res .group (1 ) + res .group (2 )
94
+ state_dict [new_key ] = state_dict [key ]
95
+ del state_dict [key ]
96
+ model .load_state_dict (state_dict )
57
97
return model
58
98
59
99
@@ -67,20 +107,33 @@ def densenet161(pretrained=False, **kwargs):
67
107
model = DenseNet (num_init_features = 96 , growth_rate = 48 , block_config = (6 , 12 , 36 , 24 ),
68
108
** kwargs )
69
109
if pretrained :
70
- model .load_state_dict (model_zoo .load_url (model_urls ['densenet161' ]))
110
+ # '.'s are no longer allowed in module names, but pervious _DenseLayer
111
+ # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
112
+ # They are also in the checkpoints in model_urls. This pattern is used
113
+ # to find such keys.
114
+ pattern = re .compile (
115
+ r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$' )
116
+ state_dict = model_zoo .load_url (model_urls ['densenet161' ])
117
+ for key in list (state_dict .keys ()):
118
+ res = pattern .match (key )
119
+ if res :
120
+ new_key = res .group (1 ) + res .group (2 )
121
+ state_dict [new_key ] = state_dict [key ]
122
+ del state_dict [key ]
123
+ model .load_state_dict (state_dict )
71
124
return model
72
125
73
126
74
127
class _DenseLayer (nn .Sequential ):
75
128
def __init__ (self , num_input_features , growth_rate , bn_size , drop_rate ):
76
129
super (_DenseLayer , self ).__init__ ()
77
- self .add_module ('norm.1 ' , nn .BatchNorm2d (num_input_features )),
78
- self .add_module ('relu.1 ' , nn .ReLU (inplace = True )),
79
- self .add_module ('conv.1 ' , nn .Conv2d (num_input_features , bn_size *
130
+ self .add_module ('norm1 ' , nn .BatchNorm2d (num_input_features )),
131
+ self .add_module ('relu1 ' , nn .ReLU (inplace = True )),
132
+ self .add_module ('conv1 ' , nn .Conv2d (num_input_features , bn_size *
80
133
growth_rate , kernel_size = 1 , stride = 1 , bias = False )),
81
- self .add_module ('norm.2 ' , nn .BatchNorm2d (bn_size * growth_rate )),
82
- self .add_module ('relu.2 ' , nn .ReLU (inplace = True )),
83
- self .add_module ('conv.2 ' , nn .Conv2d (bn_size * growth_rate , growth_rate ,
134
+ self .add_module ('norm2 ' , nn .BatchNorm2d (bn_size * growth_rate )),
135
+ self .add_module ('relu2 ' , nn .ReLU (inplace = True )),
136
+ self .add_module ('conv2 ' , nn .Conv2d (bn_size * growth_rate , growth_rate ,
84
137
kernel_size = 3 , stride = 1 , padding = 1 , bias = False )),
85
138
self .drop_rate = drop_rate
86
139
0 commit comments