267 | 267 | "id": "95kkUdRoaeMw"
268 | 268 | },
269 | 269 | "source": [
270 |     | - "Next, you will use the `text_dataset_from_directory` utility to create a labeled `tf.data.Dataset`. [tf.data](https://www.tensorflow.org/guide/data) is a powerful collection of tools for working with data. \n",
    | 270 | + "Next, you will use the `text_dataset_from_directory` utility to create a labeled `tf.data.Dataset`. [tf.data](https://www.tensorflow.org/guide/data) is a powerful collection of tools for working with data.\n",
271 | 271 | "\n",
272 |     | - "When running a machine learning experiment, it is a best practice to divide your dataset into three splits: [train](https://developers.google.com/machine-learning/glossary#training_set), [validation](https://developers.google.com/machine-learning/glossary#validation_set), and [test](https://developers.google.com/machine-learning/glossary#test-set). \n",
    | 272 | + "When running a machine learning experiment, it is a best practice to divide your dataset into three splits: [train](https://developers.google.com/machine-learning/glossary#training_set), [validation](https://developers.google.com/machine-learning/glossary#validation_set), and [test](https://developers.google.com/machine-learning/glossary#test-set).\n",
273 | 273 | "\n",
274 | 274 | "The IMDB dataset has already been divided into train and test, but it lacks a validation set. Let's create a validation set using an 80:20 split of the training data by using the `validation_split` argument below."
275 | 275 | ]

286 | 286 | "seed = 42\n",
287 | 287 | "\n",
288 | 288 | "raw_train_ds = tf.keras.utils.text_dataset_from_directory(\n",
289 |     | - " 'aclImdb/train', \n",
290 |     | - " batch_size=batch_size, \n",
291 |     | - " validation_split=0.2, \n",
292 |     | - " subset='training', \n",
    | 289 | + " 'aclImdb/train',\n",
    | 290 | + " batch_size=batch_size,\n",
    | 291 | + " validation_split=0.2,\n",
    | 292 | + " subset='training',\n",
293 | 293 | " seed=seed)"
294 | 294 | ]
295 | 295 | },

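Note in passing: when splitting with `validation_split`, the training and validation calls must share the same `seed` (or use `shuffle=False`) so the two subsets don't overlap, which is why the notebook fixes `seed = 42` above. A quick way to sanity-check what was loaded is to print a few examples; a minimal sketch, assuming the `raw_train_ds` created in this cell:

```python
# Minimal sketch: peek at a few raw reviews and their labels from the
# dataset created above (assumes `raw_train_ds` exists).
for text_batch, label_batch in raw_train_ds.take(1):
    for i in range(3):
        print("Review:", text_batch.numpy()[i])
        print("Label:", label_batch.numpy()[i])
```
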
322 | 322 | "id": "JWq1SUIrp1a-"
323 | 323 | },
324 | 324 | "source": [
325 |     | - "Notice the reviews contain raw text (with punctuation and occasional HTML tags like `<br/>`). You will show how to handle these in the following section. \n",
    | 325 | + "Notice the reviews contain raw text (with punctuation and occasional HTML tags like `<br/>`). You will see how to handle these in the following section.\n",
326 | 326 | "\n",
327 | 327 | "The labels are 0 or 1. To see which of these correspond to positive and negative movie reviews, you can check the `class_names` property on the dataset.\n"
328 | 328 | ]

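The `class_names` check referred to above could look like this; a minimal sketch, assuming `raw_train_ds` from earlier (labels follow the alphabetical order of the class subdirectories, `neg` then `pos`):

```python
# Labels are assigned from the alphabetically sorted subdirectory names,
# so 'neg' maps to 0 and 'pos' maps to 1.
print("Label 0 corresponds to", raw_train_ds.class_names[0])
print("Label 1 corresponds to", raw_train_ds.class_names[1])
```
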
366 | 366 | "outputs": [],
367 | 367 | "source": [
368 | 368 | "raw_val_ds = tf.keras.utils.text_dataset_from_directory(\n",
369 |     | - " 'aclImdb/train', \n",
370 |     | - " batch_size=batch_size, \n",
371 |     | - " validation_split=0.2, \n",
372 |     | - " subset='validation', \n",
    | 369 | + " 'aclImdb/train',\n",
    | 370 | + " batch_size=batch_size,\n",
    | 371 | + " validation_split=0.2,\n",
    | 372 | + " subset='validation',\n",
373 | 373 | " seed=seed)"
374 | 374 | ]
375 | 375 | },

382 | 382 | "outputs": [],
383 | 383 | "source": [
384 | 384 | "raw_test_ds = tf.keras.utils.text_dataset_from_directory(\n",
385 |     | - " 'aclImdb/test', \n",
    | 385 | + " 'aclImdb/test',\n",
386 | 386 | " batch_size=batch_size)"
387 | 387 | ]
388 | 388 | },

394 | 394 | "source": [
395 | 395 | "### Prepare the dataset for training\n",
396 | 396 | "\n",
397 |     | - "Next, you will standardize, tokenize, and vectorize the data using the helpful `tf.keras.layers.TextVectorization` layer. \n",
    | 397 | + "Next, you will standardize, tokenize, and vectorize the data using the helpful `tf.keras.layers.TextVectorization` layer.\n",
398 | 398 | "\n",
399 | 399 | "Standardization refers to preprocessing the text, typically to remove punctuation or HTML elements to simplify the dataset. Tokenization refers to splitting strings into tokens (for example, splitting a sentence into individual words, by splitting on whitespace). Vectorization refers to converting tokens into numbers so they can be fed into a neural network. All of these tasks can be accomplished with this layer.\n",
400 | 400 | "\n",

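For reference, a sketch of how the three steps come together with `TextVectorization`, including a custom standardizer that strips the `<br />` tags noted earlier. The vocabulary size and sequence length are illustrative values, not taken from this diff:

```python
import re
import string
import tensorflow as tf

# Custom standardization: lowercase the text, strip HTML line breaks,
# then remove punctuation.
def custom_standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')
    return tf.strings.regex_replace(
        stripped_html, '[%s]' % re.escape(string.punctuation), '')

# Illustrative hyperparameters (not from this diff).
vectorize_layer = tf.keras.layers.TextVectorization(
    standardize=custom_standardization,
    max_tokens=10000,
    output_mode='int',
    output_sequence_length=250)

# adapt() builds the vocabulary from the text-only training data.
train_text = raw_train_ds.map(lambda x, y: x)
vectorize_layer.adapt(train_text)
```
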
580 | 580 | "\n",
581 | 581 | "`.cache()` keeps data in memory after it's loaded off disk. This will ensure the dataset does not become a bottleneck while training your model. If your dataset is too large to fit into memory, you can also use this method to create a performant on-disk cache, which is more efficient to read than many small files.\n",
582 | 582 | "\n",
583 |     | - "`.prefetch()` overlaps data preprocessing and model execution while training. \n",
    | 583 | + "`.prefetch()` overlaps data preprocessing and model execution while training.\n",
584 | 584 | "\n",
585 | 585 | "You can learn more about both methods, as well as how to cache data to disk in the [data performance guide](https://www.tensorflow.org/guide/data_performance)."
586 | 586 | ]

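A sketch of the two calls in practice, assuming `train_ds`, `val_ds`, and `test_ds` are the vectorized datasets produced earlier in the notebook:

```python
# AUTOTUNE lets tf.data pick the prefetch buffer size dynamically.
AUTOTUNE = tf.data.AUTOTUNE

# Cache parsed examples after the first epoch, and overlap preprocessing
# with model execution by prefetching the next batch during training.
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)
```
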
635 | 635 | " layers.Dropout(0.2),\n",
636 | 636 | " layers.GlobalAveragePooling1D(),\n",
637 | 637 | " layers.Dropout(0.2),\n",
638 |     | - " layers.Dense(1)])\n",
    | 638 | + " layers.Dense(1, activation='sigmoid')])\n",
639 | 639 | "\n",
640 | 640 | "model.summary()"
641 | 641 | ]

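Only the tail of the model appears in this hunk. Assuming the surrounding layers match the published tutorial (an `Embedding` front end, with `max_features` defined alongside the vectorization layer), the full cell would now read roughly:

```python
from tensorflow.keras import layers

embedding_dim = 16  # illustrative value, not taken from this diff

model = tf.keras.Sequential([
    # +1 accounts for the padding token added by TextVectorization;
    # `max_features` is assumed defined earlier in the notebook.
    layers.Embedding(max_features + 1, embedding_dim),
    layers.Dropout(0.2),
    layers.GlobalAveragePooling1D(),
    layers.Dropout(0.2),
    # The sigmoid added in this commit makes the output a probability
    # in [0, 1] rather than a raw logit.
    layers.Dense(1, activation='sigmoid')])

model.summary()
```
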
674 | 674 | },
675 | 675 | "outputs": [],
676 | 676 | "source": [
677 |     | - "model.compile(loss=losses.BinaryCrossentropy(from_logits=True),\n",
    | 677 | + "model.compile(loss=losses.BinaryCrossentropy(),\n",
678 | 678 | " optimizer='adam',\n",
679 |     | - " metrics=tf.metrics.BinaryAccuracy(threshold=0.0))"
    | 679 | + " metrics=[tf.metrics.BinaryAccuracy(threshold=0.5)])"
680 | 680 | ]
681 | 681 | },
682 | 682 | {

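This hunk and the previous one must move together: once the final `Dense` layer applies a sigmoid, the model outputs probabilities rather than logits, so the loss must use `from_logits=False` (the default) and the accuracy threshold shifts from 0.0 in logit space to 0.5 in probability space. The two consistent pairings, side by side:

```python
import tensorflow as tf
from tensorflow.keras import losses

# Probability outputs (Dense(1, activation='sigmoid'), as in this commit):
model.compile(loss=losses.BinaryCrossentropy(from_logits=False),
              optimizer='adam',
              metrics=[tf.metrics.BinaryAccuracy(threshold=0.5)])

# Equivalent logit-based setup (Dense(1) with no activation, as before):
# model.compile(loss=losses.BinaryCrossentropy(from_logits=True),
#               optimizer='adam',
#               metrics=[tf.metrics.BinaryAccuracy(threshold=0.0)])
```

Wrapping the metric in a list also matches the usual `compile(metrics=[...])` convention.
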
884 | 884 | },
885 | 885 | "outputs": [],
886 | 886 | "source": [
887 |     | - "examples = [\n",
    | 887 | + "examples = tf.constant([\n",
888 | 888 | " \"The movie was great!\",\n",
889 | 889 | " \"The movie was okay.\",\n",
890 | 890 | " \"The movie was terrible...\"\n",
891 |     | - "]\n",
    | 891 | + "])\n",
892 | 892 | "\n",
893 | 893 | "export_model.predict(examples)"
894 | 894 | ]

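Two things are implicit here. Wrapping the strings in `tf.constant` hands `predict` an explicit string tensor rather than a plain Python list, which newer Keras versions handle more reliably (presumably the motivation for this change). And with the sigmoid now inside `model`, an export model that accepts raw text needs no extra activation on the end; a sketch, assuming the `vectorize_layer` and `model` built earlier:

```python
# Sketch of the export model this cell assumes: composing the
# TextVectorization layer with the trained model lets it accept raw
# strings. No trailing Activation('sigmoid') is needed, since the model
# itself now ends in a sigmoid (per this commit).
export_model = tf.keras.Sequential([
    vectorize_layer,
    model
])

export_model.compile(
    loss=losses.BinaryCrossentropy(from_logits=False),
    optimizer='adam',
    metrics=['accuracy'])
```
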
916 | 916 | "\n",
917 | 917 | "This tutorial showed how to train a binary classifier from scratch on the IMDB dataset. As an exercise, you can modify this notebook to train a multi-class classifier to predict the tag of a programming question on [Stack Overflow](http://stackoverflow.com/).\n",
918 | 918 | "\n",
919 |     | - "A [dataset](https://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz) has been prepared for you to use containing the body of several thousand programming questions (for example, \"How can I sort a dictionary by value in Python?\") posted to Stack Overflow. Each of these is labeled with exactly one tag (either Python, CSharp, JavaScript, or Java). Your task is to take a question as input, and predict the appropriate tag, in this case, Python. \n",
    | 919 | + "A [dataset](https://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz) has been prepared for you to use containing the body of several thousand programming questions (for example, \"How can I sort a dictionary by value in Python?\") posted to Stack Overflow. Each of these is labeled with exactly one tag (either Python, CSharp, JavaScript, or Java). Your task is to take a question as input, and predict the appropriate tag, in this case, Python.\n",
920 | 920 | "\n",
921 | 921 | "The dataset you will work with contains several thousand questions extracted from the much larger public Stack Overflow dataset on [BigQuery](https://console.cloud.google.com/marketplace/details/stack-exchange/stack-overflow), which contains more than 17 million posts.\n",
922 | 922 | "\n",

950 | 950 | "\n",
951 | 951 | "1. When plotting accuracy over time, change `binary_accuracy` and `val_binary_accuracy` to `accuracy` and `val_accuracy`, respectively.\n",
952 | 952 | "\n",
953 |     | - "1. Once these changes are complete, you will be able to train a multi-class classifier. "
    | 953 | + "1. Once these changes are complete, you will be able to train a multi-class classifier."
954 | 954 | ]
955 | 955 | },
956 | 956 | {

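Pulling the exercise's listed changes together, a sketch of the multi-class variant. Names like `max_features` and `embedding_dim` are assumed from the binary notebook, and the four output units correspond to the four tags:

```python
# Sketch of the multi-class setup described in the exercise (assumed
# names from the binary notebook; not code from this diff).
model = tf.keras.Sequential([
    layers.Embedding(max_features + 1, embedding_dim),
    layers.Dropout(0.2),
    layers.GlobalAveragePooling1D(),
    layers.Dropout(0.2),
    layers.Dense(4)])  # one logit per tag: Python, CSharp, JavaScript, Java

# Integer labels pair with SparseCategoricalCrossentropy; logits from
# Dense(4) pair with from_logits=True.
model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])
```
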
968 | 968 | "metadata": {
969 | 969 | "accelerator": "GPU",
970 | 970 | "colab": {
971 |     | - "collapsed_sections": [],
972 | 971 | "name": "text_classification.ipynb",
    | 972 | + "provenance": [],
973 | 973 | "toc_visible": true
974 | 974 | },
975 | 975 | "kernelspec": {