
Commit 758f363

committed
Fix pre/post-training evaluation to use same batch in nn_tutorial
The tutorial was comparing loss on different batches:

- Pre-training: evaluated on the first 64 instances (batch 0)
- Post-training: evaluated on the last batch from the training loop

This made the comparison misleading, as it wasn't measuring improvement on the same data.

Changes:

- Save the initial batch (xb_initial, yb_initial) after the first evaluation
- Use the saved initial batch for the post-training evaluation
- Add a clarifying comment about the fair comparison

Both evaluations now use the same data (the first 64 training instances), providing an accurate before/after comparison of the model's improvement on the same batch.
1 parent 4fa1fa8 commit 758f363

File tree

1 file changed (+7, −2 lines)

beginner_source/nn_tutorial.py

Lines changed: 7 additions & 2 deletions
@@ -174,6 +174,10 @@ def nll(input, target):
 yb = y_train[0:bs]
 print(loss_func(preds, yb))
 
+# Save the first batch for comparison after training
+xb_initial = xb
+yb_initial = yb
+
 
 ###############################################################################
 # Let's also implement a function to calculate the accuracy of our model.
@@ -244,9 +248,10 @@ def accuracy(out, yb):
 #
 # Let's check the loss and accuracy and compare those to what we got
 # earlier. We expect that the loss will have decreased and accuracy to
-# have increased, and they have.
+# have increased, and they have. We evaluate on the same initial batch
+# we used before training for a fair comparison.
 
-print(loss_func(model(xb), yb), accuracy(model(xb), yb))
+print(loss_func(model(xb_initial), yb_initial), accuracy(model(xb_initial), yb_initial))
 
 ###############################################################################
 # Using ``torch.nn.functional``
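
For context, here is a minimal, self-contained sketch of the before/after evaluation pattern this commit introduces. It substitutes random stand-in data and a simplified linear log-softmax model for the tutorial's MNIST setup, so the data, model, and hyperparameters below are illustrative assumptions, not the tutorial's exact code; only the xb_initial/yb_initial pattern mirrors the patch.

import torch

# Toy stand-ins (assumption: the real tutorial trains on MNIST with a
# hand-rolled log-softmax model; random data keeps this sketch self-contained).
bs = 64
x_train = torch.randn(1024, 784)
y_train = torch.randint(0, 10, (1024,))

weights = torch.randn(784, 10) / 784 ** 0.5
weights.requires_grad_()
bias = torch.zeros(10, requires_grad=True)

def model(xb):
    return (xb @ weights + bias).log_softmax(dim=1)

loss_func = torch.nn.functional.nll_loss

def accuracy(out, yb):
    return (out.argmax(dim=1) == yb).float().mean()

# Evaluate on the first batch and save it -- this is what the patch adds.
xb_initial, yb_initial = x_train[0:bs], y_train[0:bs]
print("before:", loss_func(model(xb_initial), yb_initial),
      accuracy(model(xb_initial), yb_initial))

# A short training loop over mini-batches; when it finishes, xb and yb
# hold the *last* batch, which is why the old code compared unlike data.
lr = 0.5
for epoch in range(2):
    for i in range(0, x_train.shape[0], bs):
        xb, yb = x_train[i:i + bs], y_train[i:i + bs]
        loss = loss_func(model(xb), yb)
        loss.backward()
        with torch.no_grad():
            weights -= weights.grad * lr
            bias -= bias.grad * lr
            weights.grad.zero_()
            bias.grad.zero_()

# Re-evaluate on the saved initial batch for a like-for-like comparison.
print("after:", loss_func(model(xb_initial), yb_initial),
      accuracy(model(xb_initial), yb_initial))

The key point is that both print calls see the same 64 instances, so the reported drop in loss reflects genuine improvement rather than batch-to-batch variation.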
