1+ import re
12import torch
23import torch .nn as nn
34import torch .nn .functional as F
@@ -25,7 +26,20 @@ def densenet121(pretrained=False, **kwargs):
2526 model = DenseNet (num_init_features = 64 , growth_rate = 32 , block_config = (6 , 12 , 24 , 16 ),
2627 ** kwargs )
2728 if pretrained :
28- model .load_state_dict (model_zoo .load_url (model_urls ['densenet121' ]))
29+ # '.'s are no longer allowed in module names, but previous _DenseLayer
30+ # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
31+ # They are also in the checkpoints in model_urls. This pattern is used
32+ # to find such keys.
33+ pattern = re .compile (
34+ r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$' )
35+ state_dict = model_zoo .load_url (model_urls ['densenet121' ])
36+ for key in list (state_dict .keys ()):
37+ res = pattern .match (key )
38+ if res :
39+ new_key = res .group (1 ) + res .group (2 )
40+ state_dict [new_key ] = state_dict [key ]
41+ del state_dict [key ]
42+ model .load_state_dict (state_dict )
2943 return model
3044
3145
@@ -39,7 +53,20 @@ def densenet169(pretrained=False, **kwargs):
3953 model = DenseNet (num_init_features = 64 , growth_rate = 32 , block_config = (6 , 12 , 32 , 32 ),
4054 ** kwargs )
4155 if pretrained :
42- model .load_state_dict (model_zoo .load_url (model_urls ['densenet169' ]))
56+ # '.'s are no longer allowed in module names, but previous _DenseLayer
57+ # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
58+ # They are also in the checkpoints in model_urls. This pattern is used
59+ # to find such keys.
60+ pattern = re .compile (
61+ r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$' )
62+ state_dict = model_zoo .load_url (model_urls ['densenet169' ])
63+ for key in list (state_dict .keys ()):
64+ res = pattern .match (key )
65+ if res :
66+ new_key = res .group (1 ) + res .group (2 )
67+ state_dict [new_key ] = state_dict [key ]
68+ del state_dict [key ]
69+ model .load_state_dict (state_dict )
4370 return model
4471
4572
@@ -53,7 +80,20 @@ def densenet201(pretrained=False, **kwargs):
5380 model = DenseNet (num_init_features = 64 , growth_rate = 32 , block_config = (6 , 12 , 48 , 32 ),
5481 ** kwargs )
5582 if pretrained :
56- model .load_state_dict (model_zoo .load_url (model_urls ['densenet201' ]))
83+ # '.'s are no longer allowed in module names, but previous _DenseLayer
84+ # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
85+ # They are also in the checkpoints in model_urls. This pattern is used
86+ # to find such keys.
87+ pattern = re .compile (
88+ r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$' )
89+ state_dict = model_zoo .load_url (model_urls ['densenet201' ])
90+ for key in list (state_dict .keys ()):
91+ res = pattern .match (key )
92+ if res :
93+ new_key = res .group (1 ) + res .group (2 )
94+ state_dict [new_key ] = state_dict [key ]
95+ del state_dict [key ]
96+ model .load_state_dict (state_dict )
5797 return model
5898
5999
@@ -67,20 +107,33 @@ def densenet161(pretrained=False, **kwargs):
67107 model = DenseNet (num_init_features = 96 , growth_rate = 48 , block_config = (6 , 12 , 36 , 24 ),
68108 ** kwargs )
69109 if pretrained :
70- model .load_state_dict (model_zoo .load_url (model_urls ['densenet161' ]))
110+ # '.'s are no longer allowed in module names, but previous _DenseLayer
111+ # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
112+ # They are also in the checkpoints in model_urls. This pattern is used
113+ # to find such keys.
114+ pattern = re .compile (
115+ r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$' )
116+ state_dict = model_zoo .load_url (model_urls ['densenet161' ])
117+ for key in list (state_dict .keys ()):
118+ res = pattern .match (key )
119+ if res :
120+ new_key = res .group (1 ) + res .group (2 )
121+ state_dict [new_key ] = state_dict [key ]
122+ del state_dict [key ]
123+ model .load_state_dict (state_dict )
71124 return model
72125
73126
74127class _DenseLayer (nn .Sequential ):
75128 def __init__ (self , num_input_features , growth_rate , bn_size , drop_rate ):
76129 super (_DenseLayer , self ).__init__ ()
77- self .add_module ('norm.1 ' , nn .BatchNorm2d (num_input_features )),
78- self .add_module ('relu.1 ' , nn .ReLU (inplace = True )),
79- self .add_module ('conv.1 ' , nn .Conv2d (num_input_features , bn_size *
130+ self .add_module ('norm1 ' , nn .BatchNorm2d (num_input_features )),
131+ self .add_module ('relu1 ' , nn .ReLU (inplace = True )),
132+ self .add_module ('conv1 ' , nn .Conv2d (num_input_features , bn_size *
80133 growth_rate , kernel_size = 1 , stride = 1 , bias = False )),
81- self .add_module ('norm.2 ' , nn .BatchNorm2d (bn_size * growth_rate )),
82- self .add_module ('relu.2 ' , nn .ReLU (inplace = True )),
83- self .add_module ('conv.2 ' , nn .Conv2d (bn_size * growth_rate , growth_rate ,
134+ self .add_module ('norm2 ' , nn .BatchNorm2d (bn_size * growth_rate )),
135+ self .add_module ('relu2 ' , nn .ReLU (inplace = True )),
136+ self .add_module ('conv2 ' , nn .Conv2d (bn_size * growth_rate , growth_rate ,
84137 kernel_size = 3 , stride = 1 , padding = 1 , bias = False )),
85138 self .drop_rate = drop_rate
86139
0 commit comments