Skip to content

Commit 80804ae

Browse files
committed
decreasing training time by 97% (72s on CPU) by tuning hyper parameters, adding device config for CI steps, cleaning up documentatation
1 parent 6d08a08 commit 80804ae

File tree

2 files changed

+59
-41
lines changed

2 files changed

+59
-41
lines changed

en-wordlist.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ MaskRCNN
146146
Minifier
147147
MobileNet
148148
ModelABC
149+
MPS
149150
Mypy
150151
NameData
151152
NamesDataset

intermediate_source/char_rnn_classification_tutorial.py

Lines changed: 58 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -44,33 +44,47 @@
4444
Networks <https://colah.github.io/posts/2015-08-Understanding-LSTMs/>`__
4545
is about LSTMs specifically but also informative about RNNs in
4646
general
47+
"""
4748

48-
Preparing the Data
49-
==================
50-
51-
.. note::
52-
Download the data from
53-
`here <https://download.pytorch.org/tutorial/data.zip>`_
54-
and extract it to the current directory.
55-
56-
Included in the ``data/names`` directory are 18 text files named as
57-
``[Language].txt``. Each file contains a bunch of names, one name per
58-
line, mostly romanized (but we still need to convert from Unicode to
59-
ASCII).
49+
######################################################################
50+
# Preparing Torch
51+
# ==========================
52+
#
53+
# Set up torch to default to the right device use GPU acceleration depending on your hardware (CPU or CUDA).
54+
#
6055

61-
The first thing we need to define is our data items. In this case, we will create a class called NameData
62-
which will have an __init__ function to specify the input fields and some helper functions. Our first
63-
helper function will be __str__ to convert objects to strings for easy printing
56+
import torch
6457

58+
# Check if CUDA is available
59+
device = torch.device('cpu')
60+
if torch.cuda.is_available():
61+
device = torch.device('cuda')
6562

66-
There are two key pieces of this that we will flesh out over the course of this tutorial. First is the basic data
67-
object which a label and some text. In this instance, label = the country of origin and text = the name.
63+
torch.set_default_device(device)
64+
print(f"Using device = {torch.get_default_device()}")
6865

69-
However, our data has some issues that we will need to clean up. First off, we need to convert Unicode to plain ASCII to
70-
limit the RNN input layers. This is accomplished by converting Unicode strings to ASCII and allowing a small set of allowed characters (allowed_characters)
71-
"""
66+
######################################################################
67+
# Preparing the Data
68+
# ==================
69+
#
70+
# Download the data from `here <https://download.pytorch.org/tutorial/data.zip>`__
71+
# and extract it to the current directory.
72+
#
73+
# Included in the ``data/names`` directory are 18 text files named as
74+
# ``[Language].txt``. Each file contains a bunch of names, one name per
75+
# line, mostly romanized (but we still need to convert from Unicode to
76+
# ASCII).
77+
#
78+
# The first thing we need to define is our data items. In this case, we will create a class called NameData
79+
# which will have an __init__ function to specify the input fields and some helper functions. Our first
80+
# helper function will be __str__ to convert objects to strings for easy printing
81+
#
82+
# There are two key pieces of this that we will flesh out over the course of this tutorial. First is the basic data
83+
# object which a label and some text. In this instance, label = the country of origin and text = the name.
84+
#
85+
# However, our data has some issues that we will need to clean up. First off, we need to convert Unicode to plain ASCII to
86+
# limit the RNN input layers. This is accomplished by converting Unicode strings to ASCII and allowing a small set of allowed characters (allowed_characters)
7287

73-
import torch
7488
import string
7589
import unicodedata
7690

@@ -102,7 +116,7 @@ def unicodeToAscii(s):
102116

103117
######################################################################
104118
# Turning Names into Tensors
105-
# --------------------------
119+
# ==========================
106120
#
107121
# Now that we have all the names organized, we need to turn them into
108122
# Tensors to make any use of them.
@@ -119,7 +133,6 @@ def unicodeToAscii(s):
119133
#
120134
# For this, you'll need to add a couple of capabilities to our NameData object.
121135

122-
import torch
123136
import string
124137
import unicodedata
125138

@@ -157,18 +170,18 @@ def lineToTensor(line):
157170
return tensor
158171

159172
#########################
160-
#Here are some examples of how to use the NameData object
173+
# Here are some examples of how to use the NameData object
161174

162175
print (f"{NameData(label='none', text='a')}")
163176
print (f"{NameData(label='Korean', text='Ahn')}")
164177

165178
#########################
166-
#Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach
167-
#for other RNN tasks with text.
179+
# Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach
180+
# for other RNN tasks with text.
168181
#
169-
#Next, we need to combine all our examples into a dataset so we can train, text and validate our models. For this,
170-
#we will use the `Dataset and DataLoader <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>` classes
171-
#to hold our dataset. Each Dataset needs to implement three functions: __init__, __len__, and __getitem__.
182+
# Next, we need to combine all our examples into a dataset so we can train, text and validate our models. For this,
183+
# we will use the `Dataset and DataLoader <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>` classes
184+
# to hold our dataset. Each Dataset needs to implement three functions: __init__, __len__, and __getitem__.
172185

173186
from io import open
174187
import glob
@@ -219,9 +232,10 @@ def __getitem__(self, idx):
219232

220233
#########################
221234
#Using the dataset object allows us to easily split the data into train and test sets. Here we create a 80/20
222-
#split but the torch.utils.data has more useful utilities.
235+
#split but the torch.utils.data has more useful utilities. Here we specify a generator since we need to use the
236+
#same device as torch defaults to above.
223237

224-
train_set, test_set = torch.utils.data.random_split(alldata, [.8, .2])
238+
train_set, test_set = torch.utils.data.random_split(alldata, [.8, .2], generator=torch.Generator(device=device).manual_seed(1))
225239

226240
print(f"train examples = {len(train_set)}, validation examples = {len(test_set)}")
227241

@@ -448,7 +462,10 @@ def learn(self, training_data, n_epoch = 250, n_batch_size = 64, report_every =
448462
n_hidden = 128
449463
hidden = torch.zeros(1, n_hidden)
450464
rnn = RNN(NameData.n_letters, n_hidden, alldata.labels)
451-
all_losses = rnn.learn(train_set)
465+
start = time.time()
466+
all_losses = rnn.learn(train_set, n_epoch=10, learning_rate=0.2, report_every=1)
467+
end = time.time()
468+
print(f"training took {end-start}s")
452469

453470
######################################################################
454471
# Plotting the Results
@@ -495,7 +512,7 @@ def evaluate(rnn, testing_data):
495512
# Set up plot
496513
fig = plt.figure()
497514
ax = fig.add_subplot(111)
498-
cax = ax.matshow(confusion.numpy())
515+
cax = ax.matshow(confusion.cpu().numpy()) #numpy uses cpu here so we need to use a cpu version
499516
fig.colorbar(cax)
500517

501518
# Set up axes
@@ -525,16 +542,16 @@ def evaluate(rnn, testing_data):
525542
# Exercises
526543
# =========
527544
#
528-
# - Try with a different dataset of line -> label, for example:
529-
#
530-
# - Any word -> language
531-
# - First name -> gender
532-
# - Character name -> writer
533-
# - Page title -> blog or subreddit
534-
#
535545
# - Get better results with a bigger and/or better shaped network
536546
#
547+
# - Vary the hyperparameters to improve performance (e.g. 250 epochs, batch size, learning rate )
537548
# - Add more linear layers
538549
# - Try the ``nn.LSTM`` and ``nn.GRU`` layers
539550
# - Combine multiple of these RNNs as a higher level network
540-
#
551+
#
552+
# - Try with a different dataset of line -> label, for example:
553+
#
554+
# - Any word -> language
555+
# - First name -> gender
556+
# - Character name -> writer
557+
# - Page title -> blog or subreddit

0 commit comments

Comments
 (0)