Skip to content

Commit 1f085a0

Browse files
karandwivedi42soumith
authored andcommitted
Add num_classes (#128)
1 parent 74d04d2 commit 1f085a0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchvision/models/vgg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
class VGG(nn.Module):
2121

22-
def __init__(self, features):
22+
def __init__(self, features, num_classes=1000):
2323
super(VGG, self).__init__()
2424
self.features = features
2525
self.classifier = nn.Sequential(
@@ -29,7 +29,7 @@ def __init__(self, features):
2929
nn.Linear(4096, 4096),
3030
nn.ReLU(True),
3131
nn.Dropout(),
32-
nn.Linear(4096, 1000),
32+
nn.Linear(4096, num_classes),
3333
)
3434
self._initialize_weights()
3535

0 commit comments

Comments
 (0)