97 | 97 | },
98 | 98 | "outputs": [],
99 | 99 | "source": [
100 |     | - "import torch \n",
    | 100 | + "import torch\n",
101 | 101 | "\n",
102 | 102 | "# Check if CUDA is available\n",
103 | 103 | "device = torch.device('cpu')\n",

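This hunk only drops the trailing space after `import torch`. The CUDA check announced by the comment continues past the hunk's edge; its usual completion (a sketch, not necessarily the exact cell contents) is:

```python
import torch

# Check if CUDA is available; fall back to the CPU otherwise
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
print(f"Using device = {device}")
```
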
138 | 138 | },
139 | 139 | "outputs": [],
140 | 140 | "source": [
141 |     | - "import string \n",
    | 141 | + "import string\n",
142 | 142 | "import unicodedata\n",
143 | 143 | "\n",
144 |     | - "allowed_characters = string.ascii_letters + \" .,;'\"\n",
145 |     | - "n_letters = len(allowed_characters) \n",
    | 144 | + "# We can use \"_\" to represent an out-of-vocabulary character, that is, any character we are not handling in our model\n",
    | 145 | + "allowed_characters = string.ascii_letters + \" .,;'\" + \"_\"\n",
    | 146 | + "n_letters = len(allowed_characters)\n",
146 | 147 | "\n",
147 |     | - "# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427 \n",
    | 148 | + "# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427\n",
148 | 149 | "def unicodeToAscii(s):\n",
149 | 150 | "    return ''.join(\n",
150 | 151 | "        c for c in unicodedata.normalize('NFD', s)\n",

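This is the substantive change in the diff: a `"_"` slot is appended to the vocabulary, so `n_letters` grows from 57 to 58. The tail of `unicodeToAscii` falls outside the hunk; the sketch below completes it the way the linked Stack Overflow answer does, which may differ in detail from the actual cell:

```python
import string
import unicodedata

allowed_characters = string.ascii_letters + " .,;'" + "_"
n_letters = len(allowed_characters)

def unicodeToAscii(s):
    # NFD-decompose, then keep only characters the model knows about;
    # this drops the combining marks (category 'Mn') left over from accents
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in allowed_characters
    )

print(n_letters)                    # 58 = 52 letters + " .,;'" + "_"
print(unicodeToAscii('Ślusàrski'))  # Slusarski
```
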
203 | 204 | "source": [
204 | 205 | "# Find letter index from all_letters, e.g. \"a\" = 0\n",
205 | 206 | "def letterToIndex(letter):\n",
206 |     | - "    return allowed_characters.find(letter)\n",
    | 207 | + "    # return our out-of-vocabulary character if we encounter a letter unknown to our model\n",
    | 208 | + "    if letter not in allowed_characters:\n",
    | 209 | + "        return allowed_characters.find(\"_\")\n",
    | 210 | + "    else:\n",
    | 211 | + "        return allowed_characters.find(letter)\n",
207 | 212 | "\n",
208 | 213 | "# Turn a line into a <line_length x 1 x n_letters>,\n",
209 | 214 | "# or an array of one-hot letter vectors\n",

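`letterToIndex` now routes unknown characters to the `"_"` slot. This matters because `str.find` returns -1 for a missing character, and with a one-hot builder such as the sketch below, index -1 would silently set the *last* vocabulary position instead of failing. A quick check, assuming the definitions above; the tutorial's `lineToTensor` body is not shown in this hunk, so this version is only a plausible reconstruction of the `<line_length x 1 x n_letters>` comment:

```python
import torch

# a minimal lineToTensor matching the shape described above (an assumption,
# not a quote from the notebook)
def lineToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor

print(letterToIndex('a'))         # 0
print(letterToIndex('ñ'))         # 57 -- unknown characters map to "_"
print(lineToTensor('Ahn').shape)  # torch.Size([3, 1, 58])
```
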
261 | 266 | "from io import open\n",
262 | 267 | "import glob\n",
263 | 268 | "import os\n",
264 |     | - "import time \n",
    | 269 | + "import time\n",
265 | 270 | "\n",
266 | 271 | "import torch\n",
267 | 272 | "from torch.utils.data import Dataset\n",

270 | 275 | "\n",
271 | 276 | "    def __init__(self, data_dir):\n",
272 | 277 | "        self.data_dir = data_dir #for provenance of the dataset\n",
273 |     | - "        self.load_time = time.localtime #for provenance of the dataset \n",
    | 278 | + "        self.load_time = time.localtime #for provenance of the dataset\n",
274 | 279 | "        labels_set = set() #set of all classes\n",
275 | 280 | "\n",
276 | 281 | "        self.data = []\n",
277 | 282 | "        self.data_tensors = []\n",
278 |     | - "        self.labels = [] \n",
279 |     | - "        self.labels_tensors = [] \n",
    | 283 | + "        self.labels = []\n",
    | 284 | + "        self.labels_tensors = []\n",
280 | 285 | "\n",
281 | 286 | "        #read all the ``.txt`` files in the specified directory\n",
282 |     | - "        text_files = glob.glob(os.path.join(data_dir, '*.txt')) \n",
    | 287 | + "        text_files = glob.glob(os.path.join(data_dir, '*.txt'))\n",
283 | 288 | "        for filename in text_files:\n",
284 | 289 | "            label = os.path.splitext(os.path.basename(filename))[0]\n",
285 | 290 | "            labels_set.add(label)\n",
286 | 291 | "            lines = open(filename, encoding='utf-8').read().strip().split('\\n')\n",
287 |     | - "            for name in lines: \n",
    | 292 | + "            for name in lines:\n",
288 | 293 | "                self.data.append(name)\n",
289 | 294 | "                self.data_tensors.append(lineToTensor(name))\n",
290 | 295 | "                self.labels.append(label)\n",
291 | 296 | "\n",
292 |     | - "        #Cache the tensor representation of the labels \n",
    | 297 | + "        #Cache the tensor representation of the labels\n",
293 | 298 | "        self.labels_uniq = list(labels_set)\n",
294 | 299 | "        for idx in range(len(self.labels)):\n",
295 | 300 | "            temp_tensor = torch.tensor([self.labels_uniq.index(self.labels[idx])], dtype=torch.long)\n",

302 | 307 | "        data_item = self.data[idx]\n",
303 | 308 | "        data_label = self.labels[idx]\n",
304 | 309 | "        data_tensor = self.data_tensors[idx]\n",
305 |     | - "        label_tensor = self.labels_tensors[idx] \n",
    | 310 | + "        label_tensor = self.labels_tensors[idx]\n",
306 | 311 | "\n",
307 | 312 | "        return label_tensor, data_tensor, data_label, data_item"
308 | 313 | ]

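`__getitem__` returns a four-tuple, tensors first, so training code can unpack everything in one line. A hedged usage sketch; the directory name, the `NamesDataset` class name, and the split fractions follow the published tutorial rather than anything visible in this diff, and `len(alldata)` assumes the usual `__len__` implementation:

```python
import torch

# hypothetical usage: "data/names" holds one <Language>.txt file per class
alldata = NamesDataset("data/names")
print(f"loaded {len(alldata)} items of data")
print(alldata[0])  # (label_tensor, data_tensor, data_label, data_item)

# an 85/15 train/test split, seeded for reproducibility
train_set, test_set = torch.utils.data.random_split(
    alldata, [.85, .15], generator=torch.Generator().manual_seed(2024))
```
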
402 | 407 | "        self.rnn = nn.RNN(input_size, hidden_size)\n",
403 | 408 | "        self.h2o = nn.Linear(hidden_size, output_size)\n",
404 | 409 | "        self.softmax = nn.LogSoftmax(dim=1)\n",
405 |     | - " \n",
    | 410 | + "\n",
406 | 411 | "    def forward(self, line_tensor):\n",
407 | 412 | "        rnn_out, hidden = self.rnn(line_tensor)\n",
408 | 413 | "        output = self.h2o(hidden[0])\n",

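Why `hidden[0]` rather than `rnn_out`: `nn.RNN` returns the per-step outputs plus a final hidden state of shape `(num_layers, batch, hidden_size)`, so `[0]` selects the single layer's final state, i.e. a summary of the whole name. A shape check with the tutorial's sizes:

```python
import torch
import torch.nn as nn

rnn_layer = nn.RNN(58, 128)      # input_size=58, hidden_size=128
name = torch.zeros(6, 1, 58)     # a 6-letter name as <6 x 1 x 58>
rnn_out, hidden = rnn_layer(name)
print(rnn_out.shape)    # torch.Size([6, 1, 128]) -- one output per letter
print(hidden.shape)     # torch.Size([1, 1, 128]) -- final state, one layer
print(hidden[0].shape)  # torch.Size([1, 128])    -- what h2o receives
```
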
415 | 420 | "cell_type": "markdown",
416 | 421 | "metadata": {},
417 | 422 | "source": [
418 |     | - "We can then create an RNN with 57 input nodes, 128 hidden nodes, and 18\n",
    | 423 | + "We can then create an RNN with 58 input nodes, 128 hidden nodes, and 18\n",
419 | 424 | "outputs:\n"
420 | 425 | ]
421 | 426 | },

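The 57→58 fix keeps the markdown in sync with the enlarged vocabulary: 52 ASCII letters, the five characters in `" .,;'"`, and the new `"_"`. The creation cell itself is not part of this diff; assuming the module above is named `CharRNN` as in the published tutorial, it would look like:

```python
n_hidden = 128
rnn = CharRNN(n_letters, n_hidden, len(alldata.labels_uniq))  # 58 -> 128 -> 18
print(rnn)
```
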
456 | 461 | "\n",
457 | 462 | "input = lineToTensor('Albert')\n",
458 | 463 | "output = rnn(input) #this is equivalent to ``output = rnn.forward(input)``\n",
459 |     | - "print(output) \n",
    | 464 | + "print(output)\n",
460 | 465 | "print(label_from_output(output, alldata.labels_uniq))"
461 | 466 | ]
462 | 467 | },

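`label_from_output` is called here but defined in a cell these hunks skip. A standard argmax-style definition (an assumption, not a quote from the notebook):

```python
def label_from_output(output, output_labels):
    # take the class with the highest log-probability
    top_n, top_i = output.topk(1)
    label_i = top_i[0].item()
    return output_labels[label_i], label_i
```
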
494 | 499 | },
495 | 500 | "outputs": [],
496 | 501 | "source": [
497 |     | - "import random \n",
498 |     | - "import numpy as np \n",
    | 502 | + "import random\n",
    | 503 | + "import numpy as np\n",
499 | 504 | "\n",
500 | 505 | "def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50, learning_rate = 0.2, criterion = nn.NLLLoss()):\n",
501 | 506 | "    \"\"\"\n",

504 | 509 | "    # Keep track of losses for plotting\n",
505 | 510 | "    current_loss = 0\n",
506 | 511 | "    all_losses = []\n",
507 |     | - "    rnn.train() \n",
    | 512 | + "    rnn.train()\n",
508 | 513 | "    optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate)\n",
509 | 514 | "\n",
510 | 515 | "    start = time.time()\n",
511 | 516 | "    print(f\"training on data set with n = {len(training_data)}\")\n",
512 | 517 | "\n",
513 |     | - "    for iter in range(1, n_epoch + 1): \n",
514 |     | - "        rnn.zero_grad() # clear the gradients \n",
    | 518 | + "    for iter in range(1, n_epoch + 1):\n",
    | 519 | + "        rnn.zero_grad() # clear the gradients\n",
515 | 520 | "\n",
516 | 521 | "        # create some minibatches\n",
517 | 522 | "        # we cannot use dataloaders because each of our names is a different length\n",
518 | 523 | "        batches = list(range(len(training_data)))\n",
519 | 524 | "        random.shuffle(batches)\n",
520 | 525 | "        batches = np.array_split(batches, len(batches) //n_batch_size )\n",
521 | 526 | "\n",
522 |     | - "        for idx, batch in enumerate(batches): \n",
    | 527 | + "        for idx, batch in enumerate(batches):\n",
523 | 528 | "            batch_loss = 0\n",
524 | 529 | "            for i in batch: #for each example in this batch\n",
525 | 530 | "                (label_tensor, text_tensor, label, text) = training_data[i]\n",

534 | 539 | "            optimizer.zero_grad()\n",
535 | 540 | "\n",
536 | 541 | "            current_loss += batch_loss.item() / len(batch)\n",
537 |     | - " \n",
    | 542 | + "\n",
538 | 543 | "        all_losses.append(current_loss / len(batches) )\n",
539 | 544 | "        if iter % report_every == 0:\n",
540 | 545 | "            print(f\"{iter} ({iter / n_epoch:.0%}): \\t average batch loss = {all_losses[-1]}\")\n",
541 | 546 | "        current_loss = 0\n",
542 |     | - " \n",
    | 547 | + "\n",
543 | 548 | "    return all_losses"
544 | 549 | ]
545 | 550 | },

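A hedged call sketch; the hyperparameters are illustrative (the signature above defaults to `n_epoch=10`, `learning_rate=0.2`), and plotting the returned losses assumes matplotlib as elsewhere in the tutorial:

```python
import matplotlib.pyplot as plt

# hypothetical hyperparameters for a quicker, chattier run
all_losses = train(rnn, train_set, n_epoch=27, learning_rate=0.15, report_every=5)

plt.figure()
plt.plot(all_losses)  # one point per epoch: the average batch loss
plt.show()
```
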
617 | 622 | "source": [
618 | 623 | "def evaluate(rnn, testing_data, classes):\n",
619 | 624 | "    confusion = torch.zeros(len(classes), len(classes))\n",
620 |     | - " \n",
    | 625 | + "\n",
621 | 626 | "    rnn.eval() #set to eval mode\n",
622 | 627 | "    with torch.no_grad(): # do not record the gradients during eval phase\n",
623 | 628 | "        for i in range(len(testing_data)):\n",
624 | 629 | "            (label_tensor, text_tensor, label, text) = testing_data[i]\n",
625 |     | - "            output = rnn(text_tensor) \n",
    | 630 | + "            output = rnn(text_tensor)\n",
626 | 631 | "            guess, guess_i = label_from_output(output, classes)\n",
627 | 632 | "            label_i = classes.index(label)\n",
628 | 633 | "            confusion[label_i][guess_i] += 1\n",

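The hunk stops mid-function; in the published tutorial the confusion counts are then row-normalized and drawn with `matshow`. A standalone sketch of the normalization step on a toy 3-class matrix:

```python
import torch

# toy confusion matrix: rows = true class, columns = guessed class
confusion = torch.tensor([[8., 1., 1.],
                          [2., 6., 2.],
                          [0., 3., 7.]])
# normalize each row so counts become per-class rates
confusion = confusion / confusion.sum(1, keepdim=True)
print(confusion[0])  # tensor([0.8000, 0.1000, 0.1000]) -- class 0 is 80% correct
```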