Skip to content

Commit 47bc815

Browse files
lopuhinsoumith
authored andcommitted
pass kwargs to densenet (#199)
1 parent 08b1f59 commit 47bc815

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

torchvision/models/densenet.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ def densenet121(pretrained=False, **kwargs):
2222
Args:
2323
pretrained (bool): If True, returns a model pre-trained on ImageNet
2424
"""
25-
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16))
25+
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
26+
**kwargs)
2627
if pretrained:
2728
model.load_state_dict(model_zoo.load_url(model_urls['densenet121']))
2829
return model
@@ -35,7 +36,8 @@ def densenet169(pretrained=False, **kwargs):
3536
Args:
3637
pretrained (bool): If True, returns a model pre-trained on ImageNet
3738
"""
38-
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32))
39+
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
40+
**kwargs)
3941
if pretrained:
4042
model.load_state_dict(model_zoo.load_url(model_urls['densenet169']))
4143
return model
@@ -48,7 +50,8 @@ def densenet201(pretrained=False, **kwargs):
4850
Args:
4951
pretrained (bool): If True, returns a model pre-trained on ImageNet
5052
"""
51-
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32))
53+
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
54+
**kwargs)
5255
if pretrained:
5356
model.load_state_dict(model_zoo.load_url(model_urls['densenet201']))
5457
return model
@@ -61,7 +64,8 @@ def densenet161(pretrained=False, **kwargs):
6164
Args:
6265
pretrained (bool): If True, returns a model pre-trained on ImageNet
6366
"""
64-
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24))
67+
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
68+
**kwargs)
6569
if pretrained:
6670
model.load_state_dict(model_zoo.load_url(model_urls['densenet161']))
6771
return model

0 commit comments

Comments
 (0)