Skip to content

Commit cbb05c5

Browse files
Maratyszczafmassa
authored andcommitted
Use torch.nn.init in SqueezeNet models (#146)
1 parent 874481f commit cbb05c5

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

torchvision/models/squeezenet.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22
import torch
33
import torch.nn as nn
4+
import torch.nn.init as init
45
import torch.utils.model_zoo as model_zoo
56

67

@@ -87,13 +88,10 @@ def __init__(self, version=1.0, num_classes=1000):
8788

8889
for m in self.modules():
8990
if isinstance(m, nn.Conv2d):
90-
gain = 2.0
9191
if m is final_conv:
92-
m.weight.data.normal_(0, 0.01)
92+
init.normal(m.weight.data, mean=0.0, std=0.01)
9393
else:
94-
fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
95-
u = math.sqrt(3.0 * gain / fan_in)
96-
m.weight.data.uniform_(-u, u)
94+
init.kaiming_uniform(m.weight.data)
9795
if m.bias is not None:
9896
m.bias.data.zero_()
9997

0 commit comments

Comments
 (0)