Networks <https://colah.github.io/posts/2015-08-Understanding-LSTMs/>`__
is about LSTMs specifically but also informative about RNNs in
general
+ """

- Preparing the Data
- ==================
-
- .. note::
- Download the data from
- `here <https://download.pytorch.org/tutorial/data.zip>`_
- and extract it to the current directory.
-
- Included in the ``data/names`` directory are 18 text files named as
- ``[Language].txt``. Each file contains a bunch of names, one name per
- line, mostly romanized (but we still need to convert from Unicode to
- ASCII).
+ ######################################################################
+ # Preparing Torch
+ # ==========================
+ #
+ # Set up torch to default to the right device, using GPU acceleration when your hardware supports it (CPU or CUDA).
+ #

- The first thing we need to define is our data items. In this case, we will create a class called NameData
- which will have an __init__ function to specify the input fields and some helper functions. Our first
- helper function will be __str__ to convert objects to strings for easy printing
+ import torch

+ # Check if CUDA is available
+ device = torch.device('cpu')
+ if torch.cuda.is_available():
+     device = torch.device('cuda')

- There are two key pieces of this that we will flesh out over the course of this tutorial. First is the basic data
- object which a label and some text. In this instance, label = the country of origin and text = the name.
+ torch.set_default_device(device)
+ print(f"Using device = {torch.get_default_device()}")

- However, our data has some issues that we will need to clean up. First off, we need to convert Unicode to plain ASCII to
- limit the RNN input layers. This is accomplished by converting Unicode strings to ASCII and allowing a small set of allowed characters (allowed_characters)
- """
+ ######################################################################
+ # Preparing the Data
+ # ==================
+ #
+ # Download the data from `here <https://download.pytorch.org/tutorial/data.zip>`__
+ # and extract it to the current directory.
+ #
+ # Included in the ``data/names`` directory are 18 text files named as
+ # ``[Language].txt``. Each file contains a bunch of names, one name per
+ # line, mostly romanized (but we still need to convert from Unicode to
+ # ASCII).
+ #
+ # The first thing we need to define is our data items. In this case, we will create a class called NameData
+ # which will have an ``__init__`` function to specify the input fields and some helper functions. Our first
+ # helper function will be ``__str__`` to convert objects to strings for easy printing.
+ #
+ # There are two key pieces of this that we will flesh out over the course of this tutorial. First is the basic data
+ # object, which has a label and some text. In this instance, label = the country of origin and text = the name.
+ #
+ # However, our data has some issues that we will need to clean up. First off, we need to convert Unicode to plain ASCII to
+ # limit the size of the RNN input layer. This is accomplished by converting Unicode strings to ASCII and keeping only a small set of allowed characters (``allowed_characters``).

- import torch
import string
import unicodedata

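The Unicode-to-ASCII cleanup described in the comment block above can be sketched on its own. The body of `unicodeToAscii` and the exact contents of `allowed_characters` are not visible in this diff, so the function name `unicode_to_ascii_sketch` and the character set below are assumptions made for illustration, not the tutorial's code. The idea is to decompose each accented character into a base letter plus combining marks, drop the marks, and keep only characters from the allowed set.

```python
import string
import unicodedata

# Assumption: ASCII letters plus a little punctuation. The tutorial's real
# `allowed_characters` is not shown in this diff.
ALLOWED_CHARACTERS = string.ascii_letters + " .,;'"


def unicode_to_ascii_sketch(s: str) -> str:
    """Strip accents and keep only characters from ALLOWED_CHARACTERS."""
    # NFD splits e.g. "é" into "e" followed by a combining accent (category "Mn").
    decomposed = unicodedata.normalize("NFD", s)
    return "".join(
        c for c in decomposed
        if unicodedata.category(c) != "Mn" and c in ALLOWED_CHARACTERS
    )


print(unicode_to_ascii_sketch("Ślusàrski"))
```

A smaller alphabet keeps the one-hot input vectors short, which is the reason the comment gives for doing this step before building tensors.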
@@ -102,7 +116,7 @@ def unicodeToAscii(s):

######################################################################
# Turning Names into Tensors
- # --------------------------
+ # ==========================
#
# Now that we have all the names organized, we need to turn them into
# Tensors to make any use of them.
@@ -119,7 +133,6 @@ def unicodeToAscii(s):
#
# For this, you'll need to add a couple of capabilities to our NameData object.

- import torch
import string
import unicodedata

@@ -157,18 +170,18 @@ def lineToTensor(line):
    return tensor

#########################
- #Here are some examples of how to use the NameData object
+ # Here are some examples of how to use the NameData object

print(f"{NameData(label='none', text='a')}")
print(f"{NameData(label='Korean', text='Ahn')}")

#########################
- #Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach
- #for other RNN tasks with text.
+ # Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach
+ # for other RNN tasks with text.
#
- #Next, we need to combine all our examples into a dataset so we can train, text and validate our models. For this,
- #we will use the `Dataset and DataLoader <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>` classes
- #to hold our dataset. Each Dataset needs to implement three functions: __init__, __len__, and __getitem__.
+ # Next, we need to combine all our examples into a dataset so we can train, test and validate our models. For this,
+ # we will use the `Dataset and DataLoader <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>`__ classes
+ # to hold our dataset. Each Dataset needs to implement three functions: ``__init__``, ``__len__``, and ``__getitem__``.

from io import open
import glob
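The three required methods can be illustrated with a minimal map-style dataset. This is a sketch under stated assumptions, not the dataset class from the tutorial (whose body is outside this diff): `ToyNamesDataset`, its `pairs` argument, and the in-memory sample data are hypothetical.

```python
from torch.utils.data import Dataset


class ToyNamesDataset(Dataset):
    """Hypothetical in-memory dataset of (label, text) pairs."""

    def __init__(self, pairs):
        # Store the raw examples and build a stable label vocabulary.
        self.pairs = list(pairs)
        self.labels = sorted({label for label, _ in self.pairs})

    def __len__(self):
        # Number of examples; samplers and random_split rely on this.
        return len(self.pairs)

    def __getitem__(self, idx):
        # Return one example as (label index, text).
        label, text = self.pairs[idx]
        return self.labels.index(label), text


toy = ToyNamesDataset([("Korean", "Ahn"), ("Polish", "Slusarski"), ("Korean", "Kim")])
print(len(toy), toy[1])
```

Because the class implements `__len__` and `__getitem__`, utilities such as `torch.utils.data.random_split` can work with it without knowing how the examples are stored.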
@@ -219,9 +232,10 @@ def __getitem__(self, idx):

#########################
#Using the dataset object allows us to easily split the data into train and test sets. Here we create an 80/20
- #split but the torch.utils.data has more useful utilities.
+ #split, but torch.utils.data has more useful utilities. We pass a generator here because it must be on the
+ #same device that torch defaults to above.

- train_set, test_set = torch.utils.data.random_split(alldata, [.8, .2])
+ train_set, test_set = torch.utils.data.random_split(alldata, [.8, .2], generator=torch.Generator(device=device).manual_seed(1))

print(f"train examples = {len(train_set)}, validation examples = {len(test_set)}")

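The seeded, device-matched split in this hunk can be tried in isolation. The sketch below assumes a recent PyTorch (2.x, where `torch.set_default_device` and fractional lengths in `random_split` are available); `toy_data` and `split_80_20` are hypothetical stand-ins for `alldata` and the inline call in the diff.

```python
import torch
from torch.utils.data import random_split

# Mirror the diff: pick CUDA when available and make it the default device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)

# Hypothetical stand-in for `alldata`: anything with __len__ and __getitem__.
toy_data = list(range(10))


def split_80_20(data, seed=1):
    # The generator is created on the same device as the default device,
    # so the permutation random_split draws internally can use it.
    generator = torch.Generator(device=device).manual_seed(seed)
    return random_split(data, [0.8, 0.2], generator=generator)


train_set, test_set = split_80_20(toy_data)
train_again, _ = split_80_20(toy_data)

print(len(train_set), len(test_set))
print(train_set.indices == train_again.indices)
```

Seeding the generator also makes the split reproducible, which helps when comparing runs with different hyperparameters.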
@@ -448,7 +462,10 @@ def learn(self, training_data, n_epoch = 250, n_batch_size = 64, report_every =
n_hidden = 128
hidden = torch.zeros(1, n_hidden)
rnn = RNN(NameData.n_letters, n_hidden, alldata.labels)
- all_losses = rnn.learn(train_set)
+ start = time.time()
+ all_losses = rnn.learn(train_set, n_epoch=10, learning_rate=0.2, report_every=1)
+ end = time.time()
+ print(f"training took {end - start}s")

######################################################################
# Plotting the Results
@@ -495,7 +512,7 @@ def evaluate(rnn, testing_data):
    # Set up plot
    fig = plt.figure()
    ax = fig.add_subplot(111)
-     cax = ax.matshow(confusion.numpy())
+     cax = ax.matshow(confusion.cpu().numpy())  # NumPy works on CPU memory, so copy the tensor to the CPU first
    fig.colorbar(cax)

    # Set up axes
@@ -525,16 +542,16 @@ def evaluate(rnn, testing_data):
# Exercises
# =========
#
- # - Try with a different dataset of line -> label, for example:
- #
- #    - Any word -> language
- #    - First name -> gender
- #    - Character name -> writer
- #    - Page title -> blog or subreddit
- #
# - Get better results with a bigger and/or better shaped network
#
+ #    - Vary the hyperparameters to improve performance (e.g. train for 250 epochs, or change the batch size or learning rate)
#    - Add more linear layers
#    - Try the ``nn.LSTM`` and ``nn.GRU`` layers
#    - Combine multiple of these RNNs as a higher level network
- #
+ #
+ # - Try with a different dataset of line -> label, for example:
+ #
+ #    - Any word -> language
+ #    - First name -> gender
+ #    - Character name -> writer
+ #    - Page title -> blog or subreddit