
Commit 3d2bb57

fixing training description to focus on the single train() function rather than building it up
1 parent cf3de68 commit 3d2bb57

File tree

1 file changed: +5 −15 lines


intermediate_source/char_rnn_classification_tutorial.py

Lines changed: 5 additions & 15 deletions
@@ -287,26 +287,16 @@ def label_from_output(output, output_labels):
 #
 # Now all it takes to train this network is show it a bunch of examples,
 # have it make guesses, and tell it if it's wrong.
-#
-# We start by defining a function learn_single() which learns from a single
-# piece of input data.
-#
-# - Create input and target tensors
-# - Create a zeroed initial hidden state
-# - Read each letter in and
-#
-# - Keep hidden state for next letter
-#
-# - Compare final output to target
-# - Back-propagate
-# - Return the output and loss
 #
-# We do this by defining a learn() function which trains on a given dataset with minibatches
+# We do this by defining a train() function which trains on a given dataset with minibatches. RNNs
+# train similar to other networks so for completeness we include a batched training method here.
+# The loop (for i in batch) computes the losses for each of the items in the batch before adjusting the
+# weights. This is repeated until the number of epochs is reached.

 import random
 import numpy as np

-def train(rnn, training_data, n_epoch = 250, n_batch_size = 64, report_every = 50, learning_rate = 0.005, criterion = nn.NLLLoss()):
+def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50, learning_rate = 0.2, criterion = nn.NLLLoss()):
     """
     Learn on a batch of training_data for a specified number of iterations and reporting thresholds
     """
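For context, the batched training loop the new description refers to can be sketched roughly as below. This is a minimal, self-contained approximation, not the tutorial's exact code: the toy `CharRNN` module, the `(label_tensor, text_tensor)` data format, and the plain manual SGD update are assumptions made here for illustration.

```python
# Hedged sketch of a batched train() loop for a character-level RNN classifier.
# CharRNN and the data format are illustrative stand-ins, not the tutorial's code.
import random
import torch
import torch.nn as nn

class CharRNN(nn.Module):
    """Toy stand-in: a sequence of letter vectors -> log-probabilities over labels."""
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, line_tensor):
        # line_tensor: (seq_len, 1, input_size); final hidden state summarizes the line
        _, hidden = self.rnn(line_tensor)
        return self.softmax(self.h2o(hidden[0]))   # shape (1, output_size)

def train(rnn, training_data, n_epoch=10, n_batch_size=64, report_every=50,
          learning_rate=0.2, criterion=nn.NLLLoss()):
    rnn.train()
    losses = []
    for epoch in range(1, n_epoch + 1):
        random.shuffle(training_data)
        batches = [training_data[i:i + n_batch_size]
                   for i in range(0, len(training_data), n_batch_size)]
        epoch_loss = 0.0
        for batch in batches:
            rnn.zero_grad()
            # accumulate the loss over every item in the batch, then step once
            batch_loss = torch.tensor(0.0)
            for label_tensor, text_tensor in batch:
                output = rnn(text_tensor)
                batch_loss = batch_loss + criterion(output, label_tensor)
            batch_loss.backward()
            with torch.no_grad():              # plain SGD, averaged over the batch
                for p in rnn.parameters():
                    p -= learning_rate * p.grad / len(batch)
            epoch_loss += batch_loss.item() / len(batch)
        losses.append(epoch_loss / len(batches))
        if epoch % report_every == 0:
            print(f"epoch {epoch}: avg loss {losses[-1]:.4f}")
    return losses
```

As the diff's comment says, the inner loop computes the losses for each item in the batch before adjusting the weights, and the outer loop repeats until `n_epoch` is reached.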

0 commit comments
