
Commit 27821b6

ktonthat authored and copybara-github committed
Notebook documentation for classification_with_model_garden
PiperOrigin-RevId: 447557463
1 parent 7d5ea2e commit 27821b6

2 files changed: +130 −31 lines changed

site/en/tutorials/_toc.yaml

Lines changed: 3 additions & 0 deletions
@@ -109,6 +109,9 @@ toc:
     path: /tutorials/images/data_augmentation
   - title: "Image segmentation"
     path: /tutorials/images/segmentation
+  - title: "Image classification with Model Garden"
+    path: /tutorials/images/classification_with_model_garden
+    status: new
   - title: "Object detection with TF Hub"
     path: https://github.com/tensorflow/hub/blob/master/examples/colab/tf2_object_detection.ipynb
     status: external

site/en/tutorials/images/models_vision.ipynb renamed to site/en/tutorials/images/classification_with_model_garden.ipynb

Lines changed: 127 additions & 31 deletions
@@ -37,7 +37,7 @@
  "id": "qFdPvlXBOdUN"
  },
  "source": [
- "# Use TensorFlow Models: Fine tune a ResNet"
+ "# Image classification with Model Garden"
  ]
 },
 {
@@ -48,16 +48,16 @@
  "source": [
  "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
  "  <td>\n",
- "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/images/models_vision\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
+ "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/images/classification_with_model_garden\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
  "  </td>\n",
  "  <td>\n",
- "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/models_vision.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
+ "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/classification_with_model_garden.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
  "  </td>\n",
  "  <td>\n",
- "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/models_vision.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
+ "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/classification_with_model_garden.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
  "  </td>\n",
  "  <td>\n",
- "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/images/models_vision.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
+ "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/images/classification_with_model_garden.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
  "  </td>\n",
  "</table>"
  ]
@@ -68,7 +68,16 @@
  "id": "Ta_nFXaVAqLD"
  },
  "source": [
- "This tutorial uses the TensorFlow Models package to fine-tune a ResNet."
+ "This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow [Model Garden](https://github.com/tensorflow/models) package (`tensorflow-models`) to classify images in the [CIFAR](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.\n",
+ "\n",
+ "Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.\n",
+ "\n",
+ "This tutorial uses a [ResNet](https://arxiv.org/pdf/1512.03385.pdf) model, a state-of-the-art image classifier. Specifically, it uses the ResNet-18 variant, a convolutional neural network with 18 layers.\n",
+ "\n",
+ "This tutorial demonstrates how to:\n",
+ "1. Use models from the TensorFlow Models package.\n",
+ "2. Fine-tune a pre-built ResNet for image classification.\n",
+ "3. Export the tuned ResNet model."
  ]
 },
 {
@@ -79,7 +88,7 @@
  "source": [
  "## Setup\n",
  "\n",
- "Install and import the necessary modules"
+ "Install and import the necessary modules. This tutorial uses the `tf-models-nightly` version of Model Garden."
  ]
 },
 {
@@ -94,6 +103,15 @@
  "!pip install -q tf-models-nightly"
  ]
 },
+{
+ "cell_type": "markdown",
+ "metadata": {
+  "id": "CKYMTPjOE400"
+ },
+ "source": [
+  "Import TensorFlow, TensorFlow Datasets, and a few helper libraries."
+ ]
+},
 {
  "cell_type": "code",
  "execution_count": null,
@@ -102,7 +120,6 @@
 },
 "outputs": [],
 "source": [
- "# Import helper libraries\n",
  "import pprint\n",
  "import tempfile\n",
  "\n",
@@ -113,6 +130,15 @@
  "import tensorflow_datasets as tfds"
  ]
 },
+{
+ "cell_type": "markdown",
+ "metadata": {
+  "id": "AVTs0jDd1b24"
+ },
+ "source": [
+  "The `tensorflow_models` package contains the ResNet vision model, and the `official.vision.serving` module contains the function to save and export the tuned model."
+ ]
+},
 {
  "cell_type": "code",
  "execution_count": null,
@@ -124,7 +150,7 @@
  "import tensorflow_models as tfm\n",
  "\n",
  "# Not in the tfm public API for v2.9. Will be available as `vision.serving` in v2.10\n",
- "from official.vision.serving import export_saved_model_lib "
+ "from official.vision.serving import export_saved_model_lib"
  ]
 },
 {
@@ -133,7 +159,7 @@
  "id": "aKv3wdqkQ8FU"
  },
  "source": [
- "## Cifar-10 with ResNet-18 Backbone"
+ "## Configure the ResNet-18 model for the Cifar-10 dataset"
  ]
 },
 {
@@ -142,7 +168,11 @@
  "id": "5iN8mHEJjKYE"
  },
  "source": [
- "Base the experiment on `\"resnet_imagenet\"` configuration (defined by `tfm.vision.configs.image_classification.image_classification_imagenet`)."
+ "The CIFAR-10 dataset contains 60,000 color images in 10 mutually exclusive classes, with 6,000 images in each class.\n",
+ "\n",
+ "In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
+ "\n",
+ "Use the `resnet_imagenet` factory configuration, as defined by `tfm.vision.configs.image_classification.image_classification_imagenet`. The configuration is set up to train ResNet to converge on [ImageNet](https://www.image-net.org/)."
  ]
 },
 {
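For context, a minimal sketch of how such a factory config is typically created (the `exp_factory` call is an assumption based on the Model Garden package; this cell is not shown in the diff):

```python
import tensorflow_datasets as tfds
import tensorflow_models as tfm

# Assumption: Model Garden exposes named factory configs through
# tfm.core.exp_factory; 'resnet_imagenet' is the config named above.
exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')

# Load CIFAR-10 metadata (image shape, class count) from TFDS.
ds, ds_info = tfds.load('cifar10', with_info=True)
print(ds_info.features['image'].shape)  # (32, 32, 3)
```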
@@ -165,7 +195,7 @@
  "id": "U6PVwXA-j3E7"
  },
  "source": [
- "Next adjust the configuration so that it works with `cifar10`."
+ "Adjust the model and dataset configurations so that they work with Cifar-10 (`cifar10`)."
  ]
 },
 {
@@ -176,12 +206,12 @@
 },
 "outputs": [],
 "source": [
- "# Change model\n",
+ "# Configure model\n",
  "exp_config.task.model.num_classes = 10\n",
  "exp_config.task.model.input_size = list(ds_info.features[\"image\"].shape)\n",
  "exp_config.task.model.backbone.resnet.model_id = 18\n",
  "\n",
- "# Change train, eval data\n",
+ "# Configure training and testing data\n",
  "batch_size = 128\n",
  "\n",
  "exp_config.task.train_data.input_path = ''\n",
@@ -201,7 +231,7 @@
  "id": "DE3ggKzzTD56"
  },
  "source": [
- "Adjust the trainer configuration:"
+ "Adjust the trainer configuration."
  ]
 },
 {
@@ -212,8 +242,24 @@
 },
 "outputs": [],
 "source": [
- "# Change trainer config\n",
- "train_steps = 5000\n",
+ "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
+ "\n",
+ "if 'GPU' in ''.join(logical_device_names):\n",
+ "  print('This may be broken in Colab.')\n",
+ "  device = 'GPU'\n",
+ "elif 'TPU' in ''.join(logical_device_names):\n",
+ "  print('This may be broken in Colab.')\n",
+ "  device = 'TPU'\n",
+ "else:\n",
+ "  print('This is slow, and doesn\\'t train to convergence.')\n",
+ "  device = 'CPU'\n",
+ "\n",
+ "if device == 'CPU':\n",
+ "  train_steps = 20\n",
+ "  exp_config.trainer.steps_per_loop = 5\n",
+ "else:\n",
+ "  train_steps = 5000\n",
+ "  exp_config.trainer.steps_per_loop = 100\n",
  "\n",
  "exp_config.trainer.steps_per_loop = 100\n",
  "exp_config.trainer.summary_interval = 100\n",
@@ -233,7 +279,7 @@
  "id": "5mTcDnBiTOYD"
  },
  "source": [
- "And set the runtime configuration."
+ "Print the modified configuration."
  ]
 },
 {
@@ -255,7 +301,7 @@
  "id": "w7_X0UHaRF2m"
  },
  "source": [
- "Set up the distribution strategy:"
+ "Set up the distribution strategy."
  ]
 },
 {
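A minimal sketch of what the strategy-selection cell could look like, reusing the device detection from the trainer-config hunk above (standard `tf.distribute` APIs; the cell itself is not shown in this diff):

```python
# Sketch: pick a tf.distribute strategy for the detected device.
if 'GPU' in ''.join(logical_device_names):
  distribution_strategy = tf.distribute.MirroredStrategy()
elif 'TPU' in ''.join(logical_device_names):
  tf.tpu.experimental.initialize_tpu_system()
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')
  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
  # Single-device fallback; slow, as warned above.
  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])
```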
@@ -288,7 +334,9 @@
  "id": "W4k5YH5pTjaK"
  },
  "source": [
- "Create the `Task` object (ref: `tfm.core.base_task.Task`) form the `config_definitions.TaskConfig`:"
+ "Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
+ "\n",
+ "The `Task` object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
  ]
 },
 {
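A hedged sketch of constructing the `Task` (the `task_factory` call and the `model_dir` name are assumptions; only `tfm.core.base_task.Task` is named in this diff):

```python
# Assumption: Model Garden builds Task objects via a task factory.
model_dir = tempfile.mkdtemp()

with distribution_strategy.scope():
  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)
```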
@@ -326,7 +374,7 @@
  "id": "yrwxnGDaRU0U"
  },
  "source": [
- "## Visualize Training Dataloader"
+ "## Visualize the training data"
  ]
 },
 {
@@ -335,8 +383,8 @@
  "id": "683c255c6c52"
  },
  "source": [
- "The data-loader applies a z-score normalization using \n",
- "`preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`, so the images returned by the dataset can't be directly displayed by standard tools, so rescale the minimum to 0.0 and the maximum to 1.0: "
+ "The dataloader applies a z-score normalization using \n",
+ "`preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`, so the images returned by the dataset can't be directly displayed by standard tools. The visualization code needs to rescale the data into the [0, 1] range."
  ]
 },
 {
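A simplified stand-in for the notebook's `show_batch` helper, illustrating the [0, 1] rescaling described above (the real helper is not shown in this diff; titling with plain class IDs is an assumption):

```python
import matplotlib.pyplot as plt

def show_batch(images, labels, predictions=None):
  """Rescale a z-score-normalized batch to [0, 1] and plot a 3x3 grid."""
  images = tf.cast(images, tf.float32)
  images = (images - tf.reduce_min(images)) / (tf.reduce_max(images) - tf.reduce_min(images))
  plt.figure(figsize=(10, 10))
  for i in range(9):
    plt.subplot(3, 3, i + 1)
    plt.imshow(images[i])
    title = str(int(labels[i]))
    if predictions is not None:
      title += f' -> {int(predictions[i])}'  # predicted class ID
    plt.title(title)
    plt.axis('off')
```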
@@ -356,7 +404,7 @@
  "id": "7a8582ebde7b"
  },
  "source": [
- "You can use the `tfds.core.DatasetInfo` (`ds_info` from earlier) to lookup the text descriptions of each class ID. "
+ "Use `ds_info` (which is an instance of `tfds.core.DatasetInfo`) to look up the text descriptions of each class ID."
  ]
 },
 {
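For example, the TFDS `ClassLabel` feature provides `int2str` (standard TFDS API; the exact cell contents are not shown in this diff):

```python
# Map integer class IDs to human-readable CIFAR-10 names.
label_info = ds_info.features['label']
print(label_info.int2str(1))  # 'automobile'
```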
@@ -377,7 +425,7 @@
  "id": "8c652a6fdbcf"
  },
  "source": [
- "Use these to disualize a batch of the data:"
+ "Visualize a batch of the data."
  ]
 },
 {
@@ -427,7 +475,16 @@
  "id": "v_A9VnL2RbXP"
  },
  "source": [
- "## Visualize Evaluation Dataloader"
+ "## Visualize the testing data"
+ ]
+},
+{
+ "cell_type": "markdown",
+ "metadata": {
+  "id": "AXovuumW_I2z"
+ },
+ "source": [
+  "Visualize a batch of images from the validation dataset."
  ]
 },
 {
@@ -449,7 +506,7 @@
  "id": "ihKJt2FHRi2N"
  },
  "source": [
- "## Train and Evaluate"
+ "## Train and evaluate"
  ]
 },
 {
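A hedged sketch of the training cell, driving the experiment through `tfm.core.train_lib.run_experiment` as referenced in the Task discussion above (the argument names are assumptions; the cell is not shown in this diff):

```python
# Assumption: run_experiment accepts the strategy, task, and config
# built earlier and returns the trained model plus evaluation logs.
model, eval_logs = tfm.core.train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='train_and_eval',
    params=exp_config,
    model_dir=model_dir,
    run_post_eval=True)
```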
@@ -480,6 +537,15 @@
  "tf.keras.utils.plot_model(model, show_shapes=True)"
  ]
 },
+{
+ "cell_type": "markdown",
+ "metadata": {
+  "id": "L7nVfxlBA8Gb"
+ },
+ "source": [
+  "Print the `accuracy`, `top_5_accuracy`, and `validation_loss` evaluation metrics."
+ ]
+},
 {
  "cell_type": "code",
  "execution_count": null,
@@ -492,6 +558,33 @@
  "  print(f'{key:20}: {value.numpy():.3f}')"
  ]
 },
+{
+ "cell_type": "markdown",
+ "metadata": {
+  "id": "TDys5bZ1zsml"
+ },
+ "source": [
+  "Run a batch of the processed training data through the model, and view the results."
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+  "id": "GhI7zR-Uz1JT"
+ },
+ "outputs": [],
+ "source": [
+  "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
+  "  predictions = model.predict(images)\n",
+  "  predictions = tf.argmax(predictions, axis=-1)\n",
+  "\n",
+  "show_batch(images, labels, tf.cast(predictions, tf.int32))\n",
+  "\n",
+  "if device == 'CPU':\n",
+  "  plt.title('The model was only trained for a few steps, so it is not expected to do well.')"
+ ]
+},
 {
  "cell_type": "markdown",
  "metadata": {
@@ -507,7 +600,7 @@
  "id": "9669d08c91af"
  },
  "source": [
- "The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details so you can pass `tf.uint8` images and get correct result.\n"
+ "The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statistics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results.\n"
  ]
 },
 {
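A hedged sketch of the export step; only the `export_saved_model_lib` import appears in this diff, so the argument names here are assumptions:

```python
# Sketch: export a SavedModel that accepts raw tf.uint8 images and
# applies the normalization internally.
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[32, 32],
    params=exp_config,
    checkpoint_path=tf.train.latest_checkpoint(model_dir),
    export_dir='./export/')
```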
@@ -534,7 +627,7 @@
  "id": "vVr6DxNqTyLZ"
  },
  "source": [
- "Test the exported model"
+ "Test the exported model."
  ]
 },
 {
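The test presumably reloads the SavedModel and grabs its serving signature, along these lines (standard TensorFlow APIs; `model_fn` matches the name used in the prediction hunk below):

```python
# Reload the exported model and fetch its default serving signature.
imported = tf.saved_model.load('./export/')
model_fn = imported.signatures['serving_default']
```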
@@ -556,7 +649,7 @@
  "id": "GiOp2WVIUNUZ"
  },
  "source": [
- "Visualize the predictions"
+ "Visualize the predictions."
  ]
 },
 {
@@ -573,7 +666,10 @@
  "  for image in data['image']:\n",
  "    index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]\n",
  "    predictions.append(index)\n",
- "  show_batch(data['image'], data['label'], predictions)"
+ "  show_batch(data['image'], data['label'], predictions)\n",
+ "\n",
+ "  if device == 'CPU':\n",
+ "    plt.title('The model was only trained for a few steps, so it is not expected to do well.')"
  ]
 }
],
