
Commit 6b1caea

Improve LR scheduler documentation in transfer learning tutorial

1 parent f99e9e8

1 file changed: +8 -3 lines


beginner_source/transfer_learning_tutorial.py

@@ -142,8 +142,12 @@ def imshow(inp, title=None):
 # - Scheduling the learning rate
 # - Saving the best model
 #
-# In the following, parameter ``scheduler`` is an LR scheduler object from
-# ``torch.optim.lr_scheduler``.
+# In this tutorial, `scheduler` is an LR scheduler object (e.g. StepLR).
+# For schedulers like StepLR, the recommended usage is:
+# optimizer.step() followed by scheduler.step()
+# which is why `scheduler.step()` is called once at the end of each epoch,
+# after all optimizer steps for that epoch are complete.
+


 def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
@@ -185,7 +189,8 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
                 _, preds = torch.max(outputs, 1)
                 loss = criterion(outputs, labels)

-                # backward + optimize only if in training phase
+                # backward pass + optimizer step (only in training phase)
+
                 if phase == 'train':
                     loss.backward()
                     optimizer.step()
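
To make the recommended ordering concrete, here is a minimal sketch of the pattern the new comments describe. The toy model, data, and StepLR settings below are illustrative assumptions, not taken from the tutorial itself:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

# Toy setup purely for illustration; the tutorial trains a real model
# on real data, but the step ordering is the same.
model = torch.nn.Linear(10, 2)
optimizer = SGD(model.parameters(), lr=0.01)
scheduler = StepLR(optimizer, step_size=7, gamma=0.1)
batches = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(3)]

for epoch in range(25):
    for inputs, labels in batches:
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        loss.backward()
        optimizer.step()   # one optimizer step per batch
    scheduler.step()       # one scheduler step per epoch, after all
                           # optimizer steps for that epoch are complete

With this ordering, StepLR decays the learning rate every step_size epochs as intended; PyTorch warns that calling scheduler.step() before optimizer.step() skips the first value of the learning rate schedule.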
