|
68 | 68 | "* Replace the classifier head with the number of labels of a new dataset\n",
|
69 | 69 | "* Perform transfer learning on the [UCF101 dataset](https://www.crcv.ucf.edu/data/UCF101.php)\n",
|
70 | 70 | "\n",
|
71 |
| - "The model downloaded in this tutorial is from [official/projects/movinet](https://github.com/tensorflow/models/tree/master/official/projects/movinet). This repository contains a collection of MoViNet models that TF Hub uses in the TensorFlow 2 SavedModel format." |
| 71 | + "The model downloaded in this tutorial is from [official/projects/movinet](https://github.com/tensorflow/models/tree/master/official/projects/movinet). This repository contains a collection of MoViNet models that TF Hub uses in the TensorFlow 2 SavedModel format.\n", |
| 72 | + "\n", |
| 73 | + "This transfer learning tutorial is the third part in a series of TensorFlow video tutorials. Here are the other three tutorials:\n", |
| 74 | + "\n", |
| 75 | + "- [Load video data](https://www.tensorflow.org/tutorials/load_data/video): This tutorial explains much of the code used in this document; in particular, how to preprocess and load data through the `FrameGenerator` class is explained in more detail.\n", |
| 76 | + "- [Build a 3D CNN model for video classification](https://www.tensorflow.org/tutorials/video/video_classification). Note that this tutorial uses a (2+1)D CNN that decomposes the spatial and temporal aspects of 3D data; if you are using volumetric data such as an MRI scan, consider using a 3D CNN instead of a (2+1)D CNN.\n", |
| 77 | + "- [MoViNet for streaming action recognition](https://www.tensorflow.org/hub/tutorials/movinet): Get familiar with the MoViNet models that are available on TF Hub." |
72 | 78 | ]
|
73 | 79 | },
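| + { |
| + "cell_type": "markdown", |
| + "metadata": {}, |
| + "source": [ |
| + "As a quick, hedged sketch of how TF Hub serves these models: a MoViNet SavedModel can be wrapped as a Keras layer with the `tensorflow_hub` library. The handle below is an assumption, not the method this tutorial uses; check [tfhub.dev](https://tfhub.dev) for the MoViNet variant and version you actually want." |
| + ] |
| + }, |
| + { |
| + "cell_type": "code", |
| + "execution_count": null, |
| + "metadata": {}, |
| + "outputs": [], |
| + "source": [ |
| + "import tensorflow_hub as hub\n", |
| + "\n", |
| + "# Hypothetical TF Hub handle; confirm the variant/version on tfhub.dev.\n", |
| + "hub_url = 'https://tfhub.dev/tensorflow/movinet/a0/base/kinetics-600/classification/3'\n", |
| + "\n", |
| + "# Wrap the SavedModel as a Keras layer; trainable=True allows fine-tuning.\n", |
| + "movinet_layer = hub.KerasLayer(hub_url, trainable=True)" |
| + ] |
| + }, |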
|
74 | 80 | {
|
|
111 | 117 | "import cv2\n",
|
112 | 118 | "import numpy as np\n",
|
113 | 119 | "import remotezip as rz\n",
|
| 120 | + "import seaborn as sns\n", |
114 | 121 | "import matplotlib.pyplot as plt\n",
|
115 | 122 | "\n",
|
116 | 123 | "import keras\n",
|
|
132 | 139 | },
|
133 | 140 | "source": [
|
134 | 141 | "## Load data\n",
|
135 | 142 | "\n",
136 | 143 | "The hidden cell below defines helper functions to download a slice of data from the UCF-101 dataset, and load it into a `tf.data.Dataset`. The [Loading video data tutorial](https://www.tensorflow.org/tutorials/load_data/video) provides a detailed walkthrough of this code.\n",
|
137 | 144 | "\n",
|
138 | 145 | "The `FrameGenerator` class at the end of the hidden block is the most important utility here. It creates an iterable object that can feed data into the TensorFlow data pipeline. Specifically, this class contains a Python generator that loads the video frames along with its encoded label. The generator (`__call__`) function yields the frame array produced by `frames_from_video_file` and a one-hot encoded vector of the label associated with the set of frames.\n",
|
|
598 | 605 | " verbose=1)"
|
599 | 606 | ]
|
600 | 607 | },
|
| 608 | + { |
| 609 | + "cell_type": "markdown", |
| 610 | + "metadata": { |
| 611 | + "id": "KkLl2zF8G9W0" |
| 612 | + }, |
| 613 | + "source": [ |
| 614 | + "## Evaluate the model\n", |
| 615 | + "\n", |
| 616 | + "The model achieved high accuracy on the training dataset. Next, use Keras `Model.evaluate` to evaluate it on the test set." |
| 617 | + ] |
| 618 | + }, |
| 619 | + { |
| 620 | + "cell_type": "code", |
| 621 | + "execution_count": null, |
| 622 | + "metadata": { |
| 623 | + "id": "NqgbzOiKuxxT" |
| 624 | + }, |
| 625 | + "outputs": [], |
| 626 | + "source": [ |
| 627 | + "model.evaluate(test_ds, return_dict=True)" |
| 628 | + ] |
| 629 | + }, |
| 630 | + { |
| 631 | + "cell_type": "markdown", |
| 632 | + "metadata": { |
| 633 | + "id": "OkFst2gsHBwD" |
| 634 | + }, |
| 635 | + "source": [ |
| 636 | + "To visualize model performance further, use a [confusion matrix](https://www.tensorflow.org/api_docs/python/tf/math/confusion_matrix). The confusion matrix allows you to assess the performance of the classification model beyond accuracy. In order to build the confusion matrix for this multi-class classification problem, get the actual values in the test set and the predicted values." |
| 637 | + ] |
| 638 | + }, |
| 639 | + { |
| 640 | + "cell_type": "code", |
| 641 | + "execution_count": null, |
| 642 | + "metadata": { |
| 643 | + "id": "hssSdW9XHF_j" |
| 644 | + }, |
| 645 | + "outputs": [], |
| 646 | + "source": [ |
| 647 | + "def get_actual_predicted_labels(dataset):\n", |
| 648 | + " \"\"\"\n", |
| 649 | + " Create a list of actual ground truth values and the predictions from the model.\n", |
| 650 | + "\n", |
| 651 | + " Args:\n", |
| 652 | + " dataset: An iterable data structure, such as a TensorFlow Dataset, with features and labels.\n", |
| 653 | + "\n", |
| 654 | + " Return:\n", |
| 655 | + " Ground truth and predicted values for a particular dataset.\n", |
| 656 | + " \"\"\"\n", |
| 657 | + " actual = [labels for _, labels in dataset.unbatch()]\n", |
| 658 | + " predicted = model.predict(dataset)\n", |
| 659 | + "\n", |
| 660 | + " actual = tf.stack(actual, axis=0)\n", |
| 661 | + " predicted = tf.concat(predicted, axis=0)\n", |
| 662 | + " predicted = tf.argmax(predicted, axis=1)\n", |
| 663 | + "\n", |
| 664 | + " return actual, predicted" |
| 665 | + ] |
| 666 | + }, |
| 667 | + { |
| 668 | + "cell_type": "code", |
| 669 | + "execution_count": null, |
| 670 | + "metadata": { |
| 671 | + "id": "2TmTue6THGWO" |
| 672 | + }, |
| 673 | + "outputs": [], |
| 674 | + "source": [ |
| 675 | + "def plot_confusion_matrix(actual, predicted, labels, ds_type):\n", |
| 676 | + " cm = tf.math.confusion_matrix(actual, predicted)\n", |
| 677 | + " ax = sns.heatmap(cm, annot=True, fmt='g')\n", |
| 678 | + " sns.set(rc={'figure.figsize':(12, 12)})\n", |
| 679 | + " sns.set(font_scale=1.4)\n", |
| 680 | + " ax.set_title('Confusion matrix of action recognition for ' + ds_type)\n", |
| 681 | + " ax.set_xlabel('Predicted Action')\n", |
| 682 | + " ax.set_ylabel('Actual Action')\n", |
| 683 | + " plt.xticks(rotation=90)\n", |
| 684 | + " plt.yticks(rotation=0)\n", |
| 685 | + " ax.xaxis.set_ticklabels(labels)\n", |
| 686 | + " ax.yaxis.set_ticklabels(labels)" |
| 687 | + ] |
| 688 | + }, |
| 689 | + { |
| 690 | + "cell_type": "code", |
| 691 | + "execution_count": null, |
| 692 | + "metadata": { |
| 693 | + "id": "4RK1A1C1HH6V" |
| 694 | + }, |
| 695 | + "outputs": [], |
| 696 | + "source": [ |
| 697 | + "fg = FrameGenerator(subset_paths['train'], num_frames, training = True)\n", |
| 698 | + "label_names = list(fg.class_ids_for_name.keys())" |
| 699 | + ] |
| 700 | + }, |
| 701 | + { |
| 702 | + "cell_type": "code", |
| 703 | + "execution_count": null, |
| 704 | + "metadata": { |
| 705 | + "id": "r4AFi2e5HKEO" |
| 706 | + }, |
| 707 | + "outputs": [], |
| 708 | + "source": [ |
| 709 | + "actual, predicted = get_actual_predicted_labels(test_ds)\n", |
| 710 | + "plot_confusion_matrix(actual, predicted, label_names, 'test')" |
| 711 | + ] |
| 712 | + }, |
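| + { |
| + "cell_type": "markdown", |
| + "metadata": {}, |
| + "source": [ |
| + "Because the confusion matrix separates true positives from false positives and false negatives, it also supports per-class metrics. The helper below is a minimal sketch (a hypothetical addition, not part of the original tutorial): it reads true positives off the diagonal and derives false positives and false negatives from the column and row sums." |
| + ] |
| + }, |
| + { |
| + "cell_type": "code", |
| + "execution_count": null, |
| + "metadata": {}, |
| + "outputs": [], |
| + "source": [ |
| + "def precision_recall_per_class(actual, predicted, labels):\n", |
| + "  # Hypothetical helper: per-class precision and recall from the confusion matrix.\n", |
| + "  cm = tf.math.confusion_matrix(actual, predicted).numpy()\n", |
| + "  tp = np.diag(cm)           # True positives sit on the diagonal.\n", |
| + "  fp = cm.sum(axis=0) - tp   # Column sum minus diagonal: false positives.\n", |
| + "  fn = cm.sum(axis=1) - tp   # Row sum minus diagonal: false negatives.\n", |
| + "  precision = {label: tp[i] / (tp[i] + fp[i]) for i, label in enumerate(labels)}\n", |
| + "  recall = {label: tp[i] / (tp[i] + fn[i]) for i, label in enumerate(labels)}\n", |
| + "  return precision, recall\n", |
| + "\n", |
| + "precision, recall = precision_recall_per_class(actual, predicted, label_names)" |
| + ] |
| + }, |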
601 | 713 | {
|
602 | 714 | "cell_type": "markdown",
|
603 | 715 | "metadata": {
|
|
611 | 723 | "\n",
|
612 | 724 | "In particular, using the `FrameGenerator` class used in this tutorial and the other video data and classification tutorials will help you load data into your models.\n",
|
613 | 725 | "\n",
|
614 |
| - "To learn more about video data, check out:\n", |
| 726 | + "To learn more about working with video data in TensorFlow, check out the following tutorials:\n", |
615 | 727 | "\n",
|
616 |
| - "- [Load video data](https://www.tensorflow.org/tutorials/load_data/video): This tutorial explains much of the code used in this document.\n", |
617 |
| - "- [Build a 3D CNN model for video classification](https://www.tensorflow.org/tutorials/video/video_classification). Note that this tutorial uses a (2+1)D CNN that decomposes the spatial and temporal aspects of 3D data; if you are using volumetric data such as an MRI scan, consider using a 3D CNN instead of a (2+1)D CNN." |
| 728 | + "* [Load video data](https://www.tensorflow.org/tutorials/load_data/video)\n", |
| 729 | + "* [Build a 3D CNN model for video classification](https://www.tensorflow.org/tutorials/video/video_classification)\n", |
| 730 | + "* [MoViNet for streaming action recognition](https://www.tensorflow.org/hub/tutorials/movinet)" |
618 | 731 | ]
|
619 | 732 | }
|
620 | 733 | ],
|
|