Skip to content

Commit c900f66

Browse files
committed
modify: VGG tutorial
train loader fail
1 parent f706a64 commit c900f66

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

beginner_source/Pretraining_Vgg_from_scratch.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
``Pretraining`` VGG from scratch
2+
Pretraining VGG from scratch
33
============================
44
55
@@ -29,7 +29,12 @@
2929
* Complete the `Learn the Basics tutorials <https://pytorch.org/tutorials/beginner/basics/intro.html>`__
3030
* Familiarity with basic machine learning concepts and terms
3131
32-
If you are running this in Google Colab, install albumentations
32+
If you are running this in Google Colab, install ``albumentations`` by running the following command:
33+
34+
.. code-block:: py
35+
36+
37+
!pip install albumentations
3338
3439
3540
"""
@@ -101,10 +106,10 @@
101106
#
102107
# Unlike ``AlexNet``'s 5x5 9x9 filters, VGG only uses 3x3 filters.
103108
# Using multiple 3x3 filters can obtain the same receptive field as using a 5x5 filter, but it is effective in reducing the number of parameters.
104-
# In addition, since it passes through multiple nonlinear functions, the nonlinearity increases even more.
109+
# In addition, since it passes through multiple nonlinear functions, the ``nonlinearity`` increases even more.
105110
#
106111
# VGG applied a max pooling layer after multiple convolutional layers to reduce the spatial size.
107-
# This allowed the feature map to be downsampled while preserving important information.
112+
# This allowed the feature map to be ``downsampled`` while preserving important information.
108113
# Thanks to this, the network could learn high-dimensional features in deeper layers and prevent overfitting.
109114

110115
######################################################################
@@ -447,7 +452,7 @@ def accuracy(output, target, topk=(1,)):
447452
# we use ``CIFAR100`` .
448453
#
449454

450-
if DatasetName == 'Cifar' :
455+
if DatasetName == 'CIFAR' :
451456
train_data = Custom_Cifar(root=os.getcwd(),download=True)
452457
val_data = Custom_Cifar(root=os.getcwd(),train=False,download=True)
453458
val_data.val= True

0 commit comments

Comments
 (0)