Commit b4db18d

Simplify RNN class (e.g. one forward function), adding minibatches + optimizer
1 parent 126f7ad commit b4db18d

1 file changed: intermediate_source/char_rnn_classification_tutorial.py (+68, −101)
@@ -194,7 +194,7 @@ def __init__(self, data_dir):
         for filename in text_files:
             label = os.path.splitext(os.path.basename(filename))[0]
             labels_set.add(label)
-            lines = NamesDataset.readLines(filename)
+            lines = open(filename, encoding='utf-8').read().strip().split('\n')
             for name in lines:
                 self.data.append(NameData(label=label, text=name))

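The inlined replacement for readLines() is the usual idiom for loading a small newline-delimited file. A minimal standalone sketch of the same idiom (the path and file contents are hypothetical):

# Minimal sketch of the inlined read; the path is hypothetical.
filename = 'data/names/Scottish.txt'
lines = open(filename, encoding='utf-8').read().strip().split('\n')
# .strip() drops the trailing newline so split('\n') does not produce an empty last entry
print(lines[:5])  # first few names in the file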
@@ -208,11 +208,6 @@ def __getitem__(self, idx):
         label_tensor = torch.tensor([self.labels.index(data_item.label)], dtype=torch.long)
         return label_tensor, data_item.tensor, data_item.label, data_item.text

-    # Read a file and split into lines
-    def readLines(filename):
-        lines = open(filename, encoding='utf-8').read().strip().split('\n')
-        return lines
-

 #########################
 #Here are some examples of how to use the NamesDataset object
@@ -265,9 +260,6 @@ def __init__(self, input_size, hidden_size, output_labels):
         self.h2o = nn.Linear(hidden_size, len(output_labels))
         self.softmax = nn.LogSoftmax(dim=1)

-    def initHidden(self):
-        return torch.zeros(1, self.hidden_size)
-
     def forward(self, input, hidden):
         hidden = F.tanh(self.i2h(input) + self.h2h(hidden))
         output = self.h2o(hidden)
@@ -285,19 +277,18 @@ def forward(self, input, hidden):
 ######################################################################
 # To run a step of this network we need to pass a single character input
 # and a hidden state (which we initialize as zeros at first). We'll get to
-# multi-character names during training
+# multi-character names next

 input = NameData(label='none', text='A').tensor
-hidden = torch.zeros(1, n_hidden)
-output, next_hidden = rnn(input[0], hidden)
+output, next_hidden = rnn(input[0], torch.zeros(1, n_hidden))
 print(output)

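For orientation, the shapes involved in this single step are small and fixed. A sketch under the tutorial's layout of a name tensor as [length, 1, n_letters]; the sizes below are illustrative, not from the commit:

import torch

n_letters, n_hidden = 58, 128       # illustrative sizes; n_letters depends on the allowed alphabet
char = torch.zeros(1, n_letters)    # one one-hot character, i.e. input[0]
hidden = torch.zeros(1, n_hidden)   # all-zero initial hidden state
# one step mixes the character with the previous hidden state:
#   hidden' = tanh(i2h(char) + h2h(hidden))  -> shape [1, n_hidden]
#   output  = log_softmax(h2o(hidden'))      -> shape [1, len(output_labels)]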
 ######################################################################
 # Scoring Multi-character names
 # --------------------
 # Multi-character names require just a little bit more effort, which is
 # keeping track of the hidden output and passing it back into the RNN.
-# You can see this defined in the function forward_multi()
+# You can see this updated logic in the forward() function

 import torch.nn as nn
 import torch.nn.functional as F
@@ -313,39 +304,34 @@ def __init__(self, input_size, hidden_size, output_labels):
         self.h2h = nn.Linear(hidden_size, hidden_size)
         self.h2o = nn.Linear(hidden_size, len(output_labels))
         self.softmax = nn.LogSoftmax(dim=1)
-
-    def initHidden(self):
-        return torch.zeros(1, self.hidden_size)
-
-    def forward(self, input, hidden):
-        hidden = F.tanh(self.i2h(input) + self.h2h(hidden))
-        output = self.h2o(hidden)
-        output = self.softmax(output)
-        return output, hidden

-    def forward_multi(self, line_tensor):
-        hidden = rnn.initHidden()
+    def forward(self, line_tensor):
+        hidden = torch.zeros(1, self.hidden_size)
+        output = torch.zeros(1, len(self.output_labels))

         for i in range(line_tensor.size()[0]):
-            output, hidden = rnn.forward(line_tensor[i], hidden)
+            input = line_tensor[i]
+            hidden = F.tanh(self.i2h(input) + self.h2h(hidden))
+            output = self.h2o(hidden)
+            output = self.softmax(output)

-        return output, hidden
+        return output

     def label_from_output(self, output):
         top_n, top_i = output.topk(1)
         label_i = top_i[0].item()
         return self.output_labels[label_i], label_i
+

 ###########################
 #Now we can score the output for names!


 n_hidden = 128
-hidden = torch.zeros(1, n_hidden)
 rnn = RNN(NameData.n_letters, n_hidden, alldata.labels)

 input = NameData(label='none', text='Albert').tensor
-output, next_hidden = rnn.forward_multi(input)
+output = rnn(input) # this is equivalent to output = rnn.forward(input)
 print(output)
 print(rnn.label_from_output(output))

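The comment on the new call site works because nn.Module.__call__ dispatches to forward() after running any registered hooks. A brief sketch of the equivalence for this hook-free, deterministic model:

# rnn(input) goes through nn.Module.__call__, which runs hooks and then forward();
# with no hooks registered and a deterministic forward, both paths agree
output_call = rnn(input)
output_fwd = rnn.forward(input)
assert torch.equal(output_call, output_fwd)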
@@ -375,15 +361,15 @@ def label_from_output(self, output):
 # - Back-propagate
 # - Return the output and loss
 #
-# We also define a learn_batch() function which trains on a given dataset
-
+# We also define a learn() function which trains on a given dataset with minibatches

 import torch.nn as nn
 import torch.nn.functional as F
 import random
+import numpy as np

 class RNN(nn.Module):
-    def __init__(self, input_size, hidden_size, output_labels, criterion = nn.NLLLoss()):
+    def __init__(self, input_size, hidden_size, output_labels):
         super(RNN, self).__init__()

         self.hidden_size = hidden_size
@@ -393,97 +379,76 @@ def __init__(self, input_size, hidden_size, output_labels, criterion = nn.NLLLoss()):
         self.h2h = nn.Linear(hidden_size, hidden_size)
         self.h2o = nn.Linear(hidden_size, len(output_labels))
         self.softmax = nn.LogSoftmax(dim=1)
-
-        self.criterion = criterion
-
-    def initHidden(self):
-        return torch.zeros(1, self.hidden_size)
-
-    def forward(self, input, hidden):
-        hidden = F.tanh(self.i2h(input) + self.h2h(hidden))
-        output = self.h2o(hidden)
-        output = self.softmax(output)
-        return output, hidden

-    def forward_multi(self, line_tensor):
-        hidden = self.initHidden()
+    def forward(self, line_tensor):
+        hidden = torch.zeros(1, self.hidden_size)
+        output = torch.zeros(1, len(self.output_labels))

         for i in range(line_tensor.size()[0]):
-            output, hidden = self.forward(line_tensor[i], hidden)
+            input = line_tensor[i]
+            hidden = F.tanh(self.i2h(input) + self.h2h(hidden))
+            output = self.h2o(hidden)
+            output = self.softmax(output)

-        return output, hidden
+        return output

     def label_from_output(self, output):
         top_n, top_i = output.topk(1)
         label_i = top_i[0].item()
-        return self.output_labels[label_i], label_i
+        return self.output_labels[label_i], label_i

-    def learn_single(self, label_tensor, line_tensor, learning_rate = 0.005):
-        # Train the RNN for one example with a learning rate that defaults to 0.005.
-
-        rnn.zero_grad()
-        output, hidden = self.forward_multi(line_tensor)
-
-        loss = self.criterion(output, label_tensor)
-        loss.backward()
-
-        # Add parameters' gradients to their values, multiplied by learning rate
-        for p in self.parameters():
-            p.data.add_(p.grad.data, alpha=-learning_rate)
-
-        return output, loss.item()
-
-    def learn_batch(self, training_data, n_iters = 1000, report_every = 100):
+    def learn(self, training_data, n_epoch = 1000, n_batch_size = 64, report_every = 50, learning_rate = 0.005, criterion = nn.NLLLoss()):
         """
         Learn on a batch of training_data for a specified number of iterations and reporting thresholds
         """
-
         # Keep track of losses for plotting
         current_loss = 0
         all_losses = []
+        self.train()
+        optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)

         start = time.time()
-        print(f"training data = {training_data}")
-        print(f"size = {len(training_data)}")
-
-        for iter in range(1, n_iters + 1):
-            rand_idx = random.randint(0, len(training_data) - 1)
-            (label_tensor, text_tensor, label, text) = training_data[rand_idx]
-
-            output, loss = self.learn_single(label_tensor, text_tensor)
-            current_loss += loss
-
-            # Print ``iter`` number, loss, name and guess
+        print(f"training on data set with n = {len(training_data)}")
+
+        for iter in range(1, n_epoch + 1):
+            self.zero_grad() # clear the gradients
+
+            # create some minibatches
+            # we cannot use dataloaders because each of our names is a different length
+            batches = list(range(len(training_data)))
+            random.shuffle(batches)
+            batches = np.array_split(batches, len(batches) // n_batch_size)
+
+            for idx, batch in enumerate(batches):
+                batch_loss = 0
+                for i in batch: # for each example in this batch
+                    (label_tensor, text_tensor, label, text) = training_data[i]
+                    output = self.forward(text_tensor)
+                    loss = criterion(output, label_tensor)
+                    batch_loss += loss
+
+                # optimize parameters
+                batch_loss.backward()
+                nn.utils.clip_grad_norm_(self.parameters(), 3)
+                optimizer.step()
+                optimizer.zero_grad()
+
+                current_loss += batch_loss.item() / len(batch)
+
+            all_losses.append(current_loss / len(batches))
             if iter % report_every == 0:
-                all_losses.append(current_loss / report_every)
-                print(f"{iter} ({iter / n_iters:.0%}): \t iteration loss = {all_losses[-1]}")
-                current_loss = 0
+                print(f"{iter} ({iter / n_epoch:.0%}): \t average batch loss = {all_losses[-1]}")
+            current_loss = 0

         return all_losses

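To see why np.array_split is doing the batching here, a standalone sketch with illustrative sizes (ten examples, batch size four):

import random
import numpy as np

indices = list(range(10))                             # stand-in for len(training_data) == 10
random.shuffle(indices)
batches = np.array_split(indices, len(indices) // 4)  # 10 // 4 = 2 -> two arrays of 5 indices
for batch in batches:
    print(len(batch), batch)                          # roughly equal splits, no index dropped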
-###########################
-#We can test this with one of our examples and see the output vector, loss and guess of a class from a random network.
-#
-#Here is a single input example
+##########################################################################
+# We can now train on a dataset with minibatches for a specified number of epochs

 n_hidden = 128
 hidden = torch.zeros(1, n_hidden)
 rnn = RNN(NameData.n_letters, n_hidden, alldata.labels)
-
-(label_tensor, text_tensor, label, text) = train_set[0]
-print(f"training on name = {text} with label = {label}")
-(output, loss) = rnn.learn_single(label_tensor, text_tensor)
-
-print("LogSoftmax outputs (highest score is predicted class")
-for i in range(len(output[0])):
-    print(f"\t{i}. {alldata.labels[i]} => {output[0][i]}")
-
-###########################
-#We can also train on our training data set by randomly selecting examples
-
-
-all_losses = rnn.learn_batch(train_set, n_iters=200000, report_every=10000)
+all_losses = rnn.learn(train_set)

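Because every hyperparameter now lives on learn() as a keyword argument, the defaults can be overridden at the call site. An illustrative call (the values are not from this commit):

# hypothetical call with explicit hyperparameters
all_losses = rnn.learn(train_set, n_epoch=50, n_batch_size=32,
                       report_every=10, learning_rate=0.01,
                       criterion=nn.NLLLoss())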
 ######################################################################
 # Plotting the Results
@@ -513,11 +478,12 @@ def learn_batch(self, training_data, n_iters = 1000, report_every = 100):

 def evaluate(rnn, testing_data):
     confusion = torch.zeros(len(rnn.output_labels), len(rnn.output_labels))
-
-    with torch.no_grad(): # do not record the gradients during eval phase
+
+    rnn.eval() # set to eval mode
+    with torch.no_grad(): # do not record the gradients during eval phase
         for i in range(len(testing_data)):
             (label_tensor, text_tensor, label, text) = testing_data[i]
-            (output, hidden) = rnn.forward_multi(text_tensor)
+            output = rnn.forward(text_tensor)
             guess, guess_i = rnn.label_from_output(output)
             label_i = rnn.output_labels.index(label)
             confusion[label_i][guess_i] += 1
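Each row of confusion counts a true label and each column a guess, so per-class accuracy sits on the diagonal. A sketch of reading it off, assuming the confusion tensor built above:

# hypothetical post-processing of the confusion matrix
row_totals = confusion.sum(dim=1)                       # number of test examples per true label
per_class_acc = confusion.diag() / row_totals.clamp(min=1)
for label, acc in zip(rnn.output_labels, per_class_acc):
    print(f"{label}: {acc.item():.1%}")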
@@ -541,7 +507,8 @@ def evaluate(rnn, testing_data):
 ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

 # sphinx_gallery_thumbnail_number = 2
-plt.show()
+plt.show()
+

 evaluate(rnn, test_set)
