|
37 | 37 | "id": "qFdPvlXBOdUN"
|
38 | 38 | },
|
39 | 39 | "source": [
|
40 | | - "# Use TensorFlow Models: Fine tune a ResNet"
| 40 | + "# Image classification with Model Garden" |
41 | 41 | ]
|
42 | 42 | },
|
43 | 43 | {
|
|
48 | 48 | "source": [
|
49 | 49 | "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
|
50 | 50 | " <td>\n",
|
51 | | - " <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/images/models_vision\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
| 51 | + " <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/images/classification_with_model_garden\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
52 | 52 | " </td>\n",
|
53 | 53 | " <td>\n",
|
54 | | - " <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/models_vision.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
| 54 | + " <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/classification_with_model_garden.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
55 | 55 | " </td>\n",
|
56 | 56 | " <td>\n",
|
57 | | - " <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/models_vision.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
| 57 | + " <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/classification_with_model_garden.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
58 | 58 | " </td>\n",
|
59 | 59 | " <td>\n",
|
60 | | - " <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/images/models_vision.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
| 60 | + " <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/images/classification_with_model_garden.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
61 | 61 | " </td>\n",
|
62 | 62 | "</table>"
|
63 | 63 | ]
|
|
68 | 68 | "id": "Ta_nFXaVAqLD"
|
69 | 69 | },
|
70 | 70 | "source": [
|
71 | | - "This tutorial uses the TensorFlow Models package to fine-tune a ResNet."
| 71 | + "This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow [Model Garden](https://github.com/tensorflow/models) package (`tensorflow-models`) to classify images in the [CIFAR](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.\n",
| 72 | + "\n",
| 73 | + "Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.\n",
| 74 | + "\n",
| 75 | + "This tutorial uses a [ResNet](https://arxiv.org/pdf/1512.03385.pdf) model, a state-of-the-art image classifier. Specifically, it uses ResNet-18, a convolutional neural network that is 18 layers deep.\n",
| 76 | + "\n",
| 77 | + "This tutorial demonstrates how to:\n",
| 78 | + "1. Use models from the TensorFlow Models package.\n",
| 79 | + "2. Fine-tune a pre-built ResNet for image classification.\n",
| 80 | + "3. Export the tuned ResNet model."
72 | 81 | ]
|
73 | 82 | },
|
74 | 83 | {
|
|
79 | 88 | "source": [
|
80 | 89 | "## Setup\n",
|
81 | 90 | "\n",
|
82 | | - "Install and import the necessary modules"
| 91 | + "Install and import the necessary modules. This tutorial uses the `tf-models-nightly` version of Model Garden." |
83 | 92 | ]
|
84 | 93 | },
|
85 | 94 | {
|
|
94 | 103 | "!pip install -q tf-models-nightly"
|
95 | 104 | ]
|
96 | 105 | },
|
| 106 | + { |
| 107 | + "cell_type": "markdown", |
| 108 | + "metadata": { |
| 109 | + "id": "CKYMTPjOE400" |
| 110 | + }, |
| 111 | + "source": [ |
| 112 | + "Import TensorFlow, TensorFlow Datasets, and a few helper libraries." |
| 113 | + ] |
| 114 | + }, |
97 | 115 | {
|
98 | 116 | "cell_type": "code",
|
99 | 117 | "execution_count": null,
|
|
102 | 120 | },
|
103 | 121 | "outputs": [],
|
104 | 122 | "source": [
|
105 | | - "# Import helper libraries\n",
106 | 123 | "import pprint\n",
|
107 | 124 | "import tempfile\n",
|
108 | 125 | "\n",
|
|
113 | 130 | "import tensorflow_datasets as tfds"
|
114 | 131 | ]
|
115 | 132 | },
|
| 133 | + { |
| 134 | + "cell_type": "markdown", |
| 135 | + "metadata": { |
| 136 | + "id": "AVTs0jDd1b24" |
| 137 | + }, |
| 138 | + "source": [ |
| 139 | + "The `tensorflow_models` package contains the ResNet vision model, and the `official.vision.serving` model contains the function to save and export the tuned model." |
| 140 | + ] |
| 141 | + }, |
116 | 142 | {
|
117 | 143 | "cell_type": "code",
|
118 | 144 | "execution_count": null,
|
|
124 | 150 | "import tensorflow_models as tfm\n",
|
125 | 151 | "\n",
|
126 | 152 | "# Not in the tfm public API for v2.9. Will be available as `vision.serving` in v2.10\n",
|
127 | | - "from official.vision.serving import export_saved_model_lib "
| 153 | + "from official.vision.serving import export_saved_model_lib" |
128 | 154 | ]
|
129 | 155 | },
|
130 | 156 | {
|
|
133 | 159 | "id": "aKv3wdqkQ8FU"
|
134 | 160 | },
|
135 | 161 | "source": [
|
136 | | - "## Cifar-10 with ResNet-18 Backbone"
| 162 | + "## Configure the ResNet-18 model for the CIFAR-10 dataset"
137 | 163 | ]
|
138 | 164 | },
|
139 | 165 | {
|
|
142 | 168 | "id": "5iN8mHEJjKYE"
|
143 | 169 | },
|
144 | 170 | "source": [
|
145 | | - "Base the experiment on `\"resnet_imagenet\"` configuration (defined by `tfm.vision.configs.image_classification.image_classification_imagenet`)."
| 171 | + "The CIFAR-10 dataset contains 60,000 color images in 10 mutually exclusive classes, with 6,000 images per class.\n",
| 172 | + "\n", |
| 173 | + "In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n", |
| 174 | + "\n", |
| 175 | + "Use the `resnet_imagenet` factory configuration, as defined by `tfm.vision.configs.image_classification.image_classification_imagenet`. The configuration is set up to train ResNet to converge on [ImageNet](https://www.image-net.org/)." |
146 | 176 | ]
|
147 | 177 | },
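The cell that instantiates this configuration is collapsed in the diff. A minimal sketch of what it presumably does — the `exp_config` and `ds_info` names reappear in later hunks; the exact cell contents are assumed, not shown here:

```python
import tensorflow_datasets as tfds
import tensorflow_models as tfm

# Build the ResNet/ImageNet classification experiment config from the factory.
exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')

# Load CIFAR-10; `ds_info` carries the image shape and class names used below.
ds, ds_info = tfds.load('cifar10', with_info=True)
```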
|
148 | 178 | {
|
|
165 | 195 | "id": "U6PVwXA-j3E7"
|
166 | 196 | },
|
167 | 197 | "source": [
|
168 | | - "Next adjust the configuration so that it works with `cifar10`."
| 198 | + "Adjust the model and dataset configurations so that they work with CIFAR-10 (`cifar10`)."
169 | 199 | ]
|
170 | 200 | },
|
171 | 201 | {
|
|
176 | 206 | },
|
177 | 207 | "outputs": [],
|
178 | 208 | "source": [
|
179 | | - "# Change model\n",
| 209 | + "# Configure model\n", |
180 | 210 | "exp_config.task.model.num_classes = 10\n",
|
181 | 211 | "exp_config.task.model.input_size = list(ds_info.features[\"image\"].shape)\n",
|
182 | 212 | "exp_config.task.model.backbone.resnet.model_id = 18\n",
|
183 | 213 | "\n",
|
184 | | - "# Change train, eval data\n",
| 214 | + "# Configure training and testing data\n", |
185 | 215 | "batch_size = 128\n",
|
186 | 216 | "\n",
|
187 | 217 | "exp_config.task.train_data.input_path = ''\n",
|
|
201 | 231 | "id": "DE3ggKzzTD56"
|
202 | 232 | },
|
203 | 233 | "source": [
|
204 | | - "Adjust the trainer configuration:"
| 234 | + "Adjust the trainer configuration." |
205 | 235 | ]
|
206 | 236 | },
|
207 | 237 | {
|
|
212 | 242 | },
|
213 | 243 | "outputs": [],
|
214 | 244 | "source": [
|
215 | | - "# Change trainer config\n",
216 | | - "train_steps = 5000\n",
| 245 | + "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n", |
| 246 | + "\n", |
| 247 | + "if 'GPU' in ''.join(logical_device_names):\n", |
| 248 | + " print('This may be broken in Colab.')\n", |
| 249 | + " device = 'GPU'\n", |
| 250 | + "elif 'TPU' in ''.join(logical_device_names):\n", |
| 251 | + " print('This may be broken in Colab.')\n", |
| 252 | + " device = 'TPU'\n", |
| 253 | + "else:\n", |
| 254 | + " print('This is slow, and doesn\\'t train to convergence.')\n", |
| 255 | + " device = 'CPU'\n", |
| 256 | + "\n", |
| 257 | + "if device=='CPU':\n", |
| 258 | + " train_steps = 20\n", |
| 259 | + " exp_config.trainer.steps_per_loop = 5\n", |
| 260 | + "else:\n", |
| 261 | + " train_steps = 5000\n",
| 262 | + " exp_config.trainer.steps_per_loop = 100\n", |
217 | 263 | "\n",
|
218 | | - "exp_config.trainer.steps_per_loop = 100\n",
|
219 | 265 | "exp_config.trainer.summary_interval = 100\n",
|
|
233 | 279 | "id": "5mTcDnBiTOYD"
|
234 | 280 | },
|
235 | 281 | "source": [
|
236 | | - "And set the runtime configuration."
| 282 | + "Print the modified configuration." |
237 | 283 | ]
|
238 | 284 | },
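The printing cell itself is collapsed here. One way to do it, assuming the config is an `official.modeling.hyperparams.Config` object (which supports `as_dict()`):

```python
import pprint

# Pretty-print the nested experiment config as a plain Python dict.
pprint.pprint(exp_config.as_dict())
```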
|
239 | 285 | {
|
|
255 | 301 | "id": "w7_X0UHaRF2m"
|
256 | 302 | },
|
257 | 303 | "source": [
|
258 | | - "Set up the distribution strategy:"
| 304 | + "Set up the distribution strategy." |
259 | 305 | ]
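The strategy-selection code is collapsed in this hunk. A sketch consistent with the `device` variable chosen above; the branch details (especially the TPU initialization) are illustrative, not the exact cell contents:

```python
if device == 'GPU':
  # Mirror variables across all visible GPUs.
  distribution_strategy = tf.distribute.MirroredStrategy()
elif device == 'TPU':
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  distribution_strategy = tf.distribute.TPUStrategy(resolver)
else:
  # Single-device fallback; fine for the short CPU run configured above.
  distribution_strategy = tf.distribute.OneDeviceStrategy('/CPU:0')
```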
|
260 | 306 | },
|
261 | 307 | {
|
|
288 | 334 | "id": "W4k5YH5pTjaK"
|
289 | 335 | },
|
290 | 336 | "source": [
|
291 | | - "Create the `Task` object (ref: `tfm.core.base_task.Task`) form the `config_definitions.TaskConfig`:"
| 337 | + "Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n", |
| 338 | + "\n", |
| 339 | + "The `Task` object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`." |
292 | 340 | ]
|
293 | 341 | },
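The construction itself is collapsed in the diff; a minimal sketch, assuming Model Garden's `tfm.core.task_factory` and the `tempfile` import from the setup cell:

```python
with distribution_strategy.scope():
  # A throwaway directory for checkpoints and summaries.
  model_dir = tempfile.mkdtemp()
  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)
```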
|
294 | 342 | {
|
|
326 | 374 | "id": "yrwxnGDaRU0U"
|
327 | 375 | },
|
328 | 376 | "source": [
|
329 | | - "## Visualize Training Dataloader"
| 377 | + "## Visualize the training data" |
330 | 378 | ]
|
331 | 379 | },
|
332 | 380 | {
|
|
335 | 383 | "id": "683c255c6c52"
|
336 | 384 | },
|
337 | 385 | "source": [
|
338 | | - "The data-loader applies a z-score normalization using \n",
339 | | - "`preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`, so the images returned by the dataset can't be directly displayed by standard tools, so rescale the minimum to 0.0 and the maximum to 1.0: "
| 386 | + "The dataloader applies a z-score normalization using \n", |
| 387 | + "`preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`, so the images returned by the dataset can't be directly displayed by standard tools. The visualization code needs to rescale the data into the [0,1] range." |
340 | 388 | ]
|
341 | 389 | },
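The batch-fetching cell is collapsed; a minimal sketch of the rescaling idea, assuming the `task` object built above (the variable names here are illustrative):

```python
# Pull one normalized batch from the training input pipeline.
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  break

# Map the z-score-normalized pixel values back into [0, 1] for display.
images_min = tf.reduce_min(images)
images_max = tf.reduce_max(images)
images_for_display = (images - images_min) / (images_max - images_min)
```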
|
342 | 390 | {
|
|
356 | 404 | "id": "7a8582ebde7b"
|
357 | 405 | },
|
358 | 406 | "source": [
|
359 | | - "You can use the `tfds.core.DatasetInfo` (`ds_info` from earlier) to lookup the text descriptions of each class ID. "
| 407 | + "Use `ds_info` (which is an instance of `tfds.core.DatasetInfo`) to look up the text descriptions of each class ID."
360 | 408 | ]
|
361 | 409 | },
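The lookup cell is collapsed; a minimal sketch using the TFDS `ClassLabel` feature (`label_info` is the name the plotting sketch below assumes):

```python
# `label_info` is a tfds.features.ClassLabel; int2str maps class IDs to names.
label_info = ds_info.features['label']
print(label_info.int2str(1))  # For CIFAR-10 this prints 'automobile'.
```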
|
362 | 410 | {
|
|
377 | 425 | "id": "8c652a6fdbcf"
|
378 | 426 | },
|
379 | 427 | "source": [
|
380 | | - "Use these to disualize a batch of the data:"
| 428 | + "Visualize a batch of the data." |
381 | 429 | ]
|
382 | 430 | },
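The helper's definition is collapsed in the diff. A sketch matching the `show_batch(images, labels, predictions=None)` calls visible later in this notebook; the plotting details are illustrative:

```python
import matplotlib.pyplot as plt

def show_batch(images, labels, predictions=None):
  plt.figure(figsize=(10, 10))
  # Rescale the normalized batch into [0, 1] so imshow can render it.
  min_value = images.numpy().min()
  scale = images.numpy().max() - min_value

  for i in range(12):
    plt.subplot(4, 3, i + 1)
    plt.imshow((images[i].numpy() - min_value) / scale)
    if predictions is None:
      plt.title(label_info.int2str(int(labels[i])))
    else:
      # Green title for a correct prediction, red for an incorrect one.
      color = 'g' if int(labels[i]) == int(predictions[i]) else 'r'
      plt.title(label_info.int2str(int(predictions[i])), color=color)
    plt.axis('off')

for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  show_batch(images, labels)
```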
|
383 | 431 | {
|
|
427 | 475 | "id": "v_A9VnL2RbXP"
|
428 | 476 | },
|
429 | 477 | "source": [
|
430 | | - "## Visualize Evaluation Dataloader"
| 478 | + "## Visualize the testing data" |
| 479 | + ] |
| 480 | + }, |
| 481 | + { |
| 482 | + "cell_type": "markdown", |
| 483 | + "metadata": { |
| 484 | + "id": "AXovuumW_I2z" |
| 485 | + }, |
| 486 | + "source": [ |
| 487 | + "Visualize a batch of images from the validation (test) dataset."
431 | 488 | ]
|
432 | 489 | },
|
433 | 490 | {
|
|
449 | 506 | "id": "ihKJt2FHRi2N"
|
450 | 507 | },
|
451 | 508 | "source": [
|
452 | | - "## Train and Evaluate"
| 509 | + "## Train and evaluate" |
453 | 510 | ]
|
454 | 511 | },
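The training cell is collapsed; a minimal sketch of the driver call, assuming the `tfm.core.train_lib.run_experiment` signature referenced in the markdown above:

```python
model, eval_logs = tfm.core.train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='train_and_eval',
    params=exp_config,
    model_dir=model_dir,
    run_post_eval=True)  # Returns the trained model plus evaluation logs.
```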
|
455 | 512 | {
|
|
480 | 537 | "tf.keras.utils.plot_model(model, show_shapes=True)"
|
481 | 538 | ]
|
482 | 539 | },
|
| 540 | + { |
| 541 | + "cell_type": "markdown", |
| 542 | + "metadata": { |
| 543 | + "id": "L7nVfxlBA8Gb" |
| 544 | + }, |
| 545 | + "source": [ |
| 546 | + "Print the `accuracy`, `top_5_accuracy`, and `validation_loss` evaluation metrics." |
| 547 | + ] |
| 548 | + }, |
483 | 549 | {
|
484 | 550 | "cell_type": "code",
|
485 | 551 | "execution_count": null,
|
|
492 | 558 | " print(f'{key:20}: {value.numpy():.3f}')"
|
493 | 559 | ]
|
494 | 560 | },
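Only the innermost `print` survives in this hunk; the surrounding loop presumably iterates the `eval_logs` dict returned by `run_experiment`, along these lines:

```python
# Print only the scalar tensor metrics from the post-training evaluation.
for key, value in eval_logs.items():
  if isinstance(value, tf.Tensor):
    print(f'{key:20}: {value.numpy():.3f}')
```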
|
| 561 | + { |
| 562 | + "cell_type": "markdown", |
| 563 | + "metadata": { |
| 564 | + "id": "TDys5bZ1zsml" |
| 565 | + }, |
| 566 | + "source": [ |
| 567 | + "Run a batch of the processed training data through the model, and view the results."
| 568 | + ] |
| 569 | + }, |
| 570 | + { |
| 571 | + "cell_type": "code", |
| 572 | + "execution_count": null, |
| 573 | + "metadata": { |
| 574 | + "id": "GhI7zR-Uz1JT" |
| 575 | + }, |
| 576 | + "outputs": [], |
| 577 | + "source": [ |
| 578 | + "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n", |
| 579 | + " predictions = model.predict(images)\n", |
| 580 | + " predictions = tf.argmax(predictions, axis=-1)\n", |
| 581 | + "\n", |
| 582 | + "show_batch(images, labels, tf.cast(predictions, tf.int32))\n", |
| 583 | + "\n", |
| 584 | + "if device=='CPU':\n", |
| 585 | + " plt.title('The model was only trained for a few steps, so it is not expected to do well.')" |
| 586 | + ] |
| 587 | + }, |
495 | 588 | {
|
496 | 589 | "cell_type": "markdown",
|
497 | 590 | "metadata": {
|
|
507 | 600 | "id": "9669d08c91af"
|
508 | 601 | },
|
509 | 602 | "source": [
|
510 | | - "The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details so you can pass `tf.uint8` images and get correct result.\n"
| 603 | + "The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statistics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results.\n"
511 | 604 | ]
|
512 | 605 | },
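The export cell is collapsed; a sketch of the call, assuming `export_saved_model_lib.export_inference_graph` (imported in the setup) and the latest checkpoint under `model_dir`:

```python
# Export a SavedModel that accepts raw uint8 image tensors.
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[32, 32],
    params=exp_config,
    checkpoint_path=tf.train.latest_checkpoint(model_dir),
    export_dir='./export/')
```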
|
513 | 606 | {
|
|
534 | 627 | "id": "vVr6DxNqTyLZ"
|
535 | 628 | },
|
536 | 629 | "source": [
|
537 | | - "Test the exported model"
| 630 | + "Test the exported model." |
538 | 631 | ]
|
539 | 632 | },
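The loading cell is collapsed; a minimal sketch that reloads the SavedModel and grabs the serving signature used as `model_fn` in the cell below (the `./export/` path is assumed from the export step):

```python
# Reload the exported model and fetch its default serving signature.
imported = tf.saved_model.load('./export/')
model_fn = imported.signatures['serving_default']
```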
|
540 | 633 | {
|
|
556 | 649 | "id": "GiOp2WVIUNUZ"
|
557 | 650 | },
|
558 | 651 | "source": [
|
559 | | - "Visualize the predictions"
| 652 | + "Visualize the predictions." |
560 | 653 | ]
|
561 | 654 | },
|
562 | 655 | {
|
|
573 | 666 | " for image in data['image']:\n",
|
574 | 667 | " index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]\n",
|
575 | 668 | " predictions.append(index)\n",
|
576 | | - " show_batch(data['image'], data['label'], predictions)"
| 669 | + " show_batch(data['image'], data['label'], predictions)\n", |
| 670 | + "\n", |
| 671 | + " if device=='CPU':\n", |
| 672 | + " plt.title('The model was only trained for a few steps, so it is not expected to do well.')"
577 | 673 | ]
|
578 | 674 | }
|
579 | 675 | ],
|
|