|
80 | 80 | "\n",
|
81 | 81 | "## What is image segmentation?\n",
|
82 | 82 | "\n",
|
83 |
| - "In an image classification task the network assigns a label (or class) to each input image. However, suppose you want to know the shape of that object, which pixel belongs to which object, etc. In this case you will want to assign a class to each pixel of the image. This task is known as segmentation. A segmentation model returns much more detailed intofmation about the image. Image segmentation has many applications in medical imaging, self-driving cars and satellite imaging to name a few.\n", |
| 83 | + "In an image classification task the network assigns a label (or class) to each input image. However, suppose you want to know the shape of that object, which pixel belongs to which object, etc. In this case you will want to assign a class to each pixel of the image. This task is known as segmentation. A segmentation model returns much more detailed information about the image. Image segmentation has many applications in medical imaging, self-driving cars and satellite imaging to name a few.\n", |
84 | 84 | "\n",
|
85 |
| - "This tutorial uses the [Oxford-IIIT Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/), created by Parkhi *et al*. The dataset consists of images of 37 pet breeds, with 200 images per breed (~100 each in the train and test split). Each image includes the corresponding labels, and pixel-wise masks. The masks are class-labels for each pixel. Each pixel is given one of three categories :\n", |
| 85 | + "This tutorial uses the [Oxford-IIIT Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/) ([Parkhi et al, 2012](https://www.robots.ox.ac.uk/~vgg/publications/2012/parkhi12a/parkhi12a.pdf)). The dataset consists of images of 37 pet breeds, with 200 images per breed (~100 each in the training and test splits). Each image includes the corresponding labels, and pixel-wise masks. The masks are class-labels for each pixel. Each pixel is given one of three categories:\n", |
86 | 86 | "\n",
|
87 |
| - "* Class 1 : Pixel belonging to the pet.\n", |
88 |
| - "* Class 2 : Pixel bordering the pet.\n", |
89 |
| - "* Class 3 : None of the above/ Surrounding pixel." |
| 87 | + "- Class 1: Pixel belonging to the pet.\n", |
| 88 | + "- Class 2: Pixel bordering the pet.\n", |
| 89 | + "- Class 3: None of the above/a surrounding pixel." |
90 | 90 | ]
|
91 | 91 | },
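As a rough orientation for the mask labels described above, here is a minimal sketch (not part of the original notebook) of loading the dataset with TensorFlow Datasets and inspecting one example; it assumes the `oxford_iiit_pet:3.*.*` dataset name used later in the notebook.

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Load the Oxford-IIIT Pet dataset; `with_info=True` also returns metadata
# such as the number of examples per split.
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
print(info.splits['train'].num_examples)

# Each element is a dict with an 'image' and a 'segmentation_mask'.
sample = next(iter(dataset['train']))
print(sample['image'].shape)              # (H, W, 3); the size varies per image
print(sample['segmentation_mask'].shape)  # (H, W, 1); one class label per pixel

# The raw mask values are {1, 2, 3}; the notebook later shifts them to {0, 1, 2}.
print(tf.unique(tf.reshape(sample['segmentation_mask'], [-1]))[0])
```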
|
92 | 92 | {
|
|
196 | 196 | "id": "65-qHTjX5VZh"
|
197 | 197 | },
|
198 | 198 | "source": [
|
199 |
| - "The dataset already contains the required splits of test and train and so continue to use the same split." |
| 199 | + "The dataset already contains the required training and test splits, so continue to use the same splits." |
200 | 200 | ]
|
201 | 201 | },
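The pipeline that consumes those splits typically resizes and normalizes both the image and its mask. The following is a sketch of what that can look like, reusing the `dataset` name from the loading sketch above; the notebook's own cells may differ in detail (for example, the exact resize method).

```python
def normalize(input_image, input_mask):
  # Scale pixels to [0, 1] and shift mask labels from {1, 2, 3} to {0, 1, 2}.
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask

def load_image(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  # Use nearest-neighbor resizing for the mask so the labels stay integral.
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128),
                               method='nearest')
  return normalize(input_image, input_mask)

train_images = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
test_images = dataset['test'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
```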
|
202 | 202 | {
|
|
232 | 232 | },
|
233 | 233 | "source": [
|
234 | 234 | "The following class performs a simple augmentation by randomly-flipping an image.\n",
|
235 |
| - "See the [image augmentation tutorial](https://www.tensorflow.org/tutorials/images/data_augmentation) for more on image augmentation.\n" |
| 235 | + "Go to the [Image augmentation](data_augmentation.ipynb) tutorial to learn more.\n" |
236 | 236 | ]
|
237 | 237 | },
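One way such an augmentation layer can be written (a sketch; the class in the notebook may differ in detail) is to flip the image and its mask with the same random seed so the two stay aligned:

```python
class Augment(tf.keras.layers.Layer):
  def __init__(self, seed=42):
    super().__init__()
    # Both layers share the same seed, so image and mask get the same flip.
    self.augment_inputs = tf.keras.layers.RandomFlip(mode='horizontal', seed=seed)
    self.augment_labels = tf.keras.layers.RandomFlip(mode='horizontal', seed=seed)

  def call(self, inputs, labels):
    return self.augment_inputs(inputs), self.augment_labels(labels)
```

In older TensorFlow releases, `RandomFlip` lives under `tf.keras.layers.experimental.preprocessing` instead of `tf.keras.layers`.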
|
238 | 238 | {
|
|
291 | 291 | "id": "Xa3gMAE_9qNa"
|
292 | 292 | },
|
293 | 293 | "source": [
|
294 |
| - "Take a look at an image example and it's correponding mask from the dataset." |
| 294 | + "Visualize an image example and its corresponding mask from the dataset." |
295 | 295 | ]
|
296 | 296 | },
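A small helper along these lines (a sketch, not necessarily the notebook's exact code) can plot an image/mask pair side by side; in older TensorFlow versions `array_to_img` is found under `tf.keras.preprocessing.image`:

```python
import matplotlib.pyplot as plt

def display(display_list):
  # Show the input image, the true mask, and (optionally) a predicted mask.
  plt.figure(figsize=(15, 15))
  title = ['Input Image', 'True Mask', 'Predicted Mask']
  for i, item in enumerate(display_list):
    plt.subplot(1, len(display_list), i + 1)
    plt.title(title[i])
    plt.imshow(tf.keras.utils.array_to_img(item))
    plt.axis('off')
  plt.show()

# Example usage with one batch from an assumed batched pipeline `train_batches`:
# for images, masks in train_batches.take(1):
#   display([images[0], masks[0]])
```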
|
297 | 297 | {
|
|
335 | 335 | },
|
336 | 336 | "source": [
|
337 | 337 | "## Define the model\n",
|
338 |
| - "The model being used here is a modified [U-Net](https://arxiv.org/abs/1505.04597). A U-Net consists of an encoder (downsampler) and decoder (upsampler). In-order to learn robust features and reduce the number of trainable parameters, you will use a pretrained model - MobileNetV2 - as the encoder. For the decoder, you will use the upsample block, which is already implemented in the [Pix2pix tutorial](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py) tutorial in the TensorFlow Examples repo.\n" |
| 338 | + "The model being used here is a modified [U-Net](https://arxiv.org/abs/1505.04597). A U-Net consists of an encoder (downsampler) and decoder (upsampler). In-order to learn robust features and reduce the number of trainable parameters, you will use a pretrained model - MobileNetV2 - as the encoder. For the decoder, you will use the upsample block, which is already implemented in the [pix2pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py) example in the TensorFlow Examples repo. (Check out the [pix2pix: Image-to-image translation with a conditional GAN](../generative/pix2pix.ipynb) tutorial in a notebook.)\n" |
339 | 339 | ]
|
340 | 340 | },
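Roughly, the encoder/decoder wiring can look like the sketch below. It follows the approach described above; the exact layer choices in the notebook may differ. The `pix2pix` module comes from the `tensorflow_examples` package, installable with `pip install git+https://github.com/tensorflow/examples.git`.

```python
from tensorflow_examples.models.pix2pix import pix2pix

base_model = tf.keras.applications.MobileNetV2(
    input_shape=[128, 128, 3], include_top=False)

# Activations at these intermediate layers act as the skip connections.
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

# The encoder (downsampler) is frozen, so only the decoder is trained.
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)
down_stack.trainable = False

# The decoder (upsampler): pix2pix.upsample builds Conv2DTranspose blocks.
up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]

def unet_model(output_channels):
  inputs = tf.keras.layers.Input(shape=[128, 128, 3])

  # Downsample through the encoder, collecting the skip connections.
  skips = down_stack(inputs)
  x = skips[-1]
  skips = reversed(skips[:-1])

  # Upsample and concatenate the skip connections.
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  # The last layer has one filter per class (see the note below).
  last = tf.keras.layers.Conv2DTranspose(
      filters=output_channels, kernel_size=3, strides=2, padding='same')
  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)
```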
|
341 | 341 | {
|
|
436 | 436 | "id": "LRsjdZuEnZfA"
|
437 | 437 | },
|
438 | 438 | "source": [
|
439 |
| - "Note that on the number of filters on the last layer is set to the number of `output_channels`. This will be one output channel per class." |
| 439 | + "Note that the number of filters on the last layer is set to the number of `output_channels`. This will be one output channel per class." |
440 | 440 | ]
|
441 | 441 | },
|
442 | 442 | {
|
|
449 | 449 | "\n",
|
450 | 450 | "Now, all that is left to do is to compile and train the model. \n",
|
451 | 451 | "\n",
|
452 |
| - "SInce this is a multiclass classification problem a `CategoricalCrossentropy` with `from_logits=True` is the standard loss function. Use `losses.SparseCategoricalCrossentropy(from_logits=True)` since the labels are scalar integers instead of vectors of scores for each pixel of every class. \n", |
| 452 | + "Since this is a multiclass classification problem, use the `tf.keras.losses.CategoricalCrossentropy` loss function with the `from_logits` argument set to `True`, since the labels are scalar integers instead of vectors of scores for each pixel of every class. \n", |
453 | 453 | "\n",
|
454 | 454 | "When running inference, the label assigned to the pixel is the channel with the highest value. This is what the `create_mask` function is doing."
|
455 | 455 | ]
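Concretely, compiling the model and defining `create_mask` can look like this (a sketch building on the `unet_model` outline above):

```python
OUTPUT_CLASSES = 3

model = unet_model(output_channels=OUTPUT_CLASSES)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

def create_mask(pred_mask):
  # The model emits one logit per class for every pixel; the predicted label
  # is the channel with the highest value.
  pred_mask = tf.argmax(pred_mask, axis=-1)
  pred_mask = pred_mask[..., tf.newaxis]
  return pred_mask[0]
```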
|
|
496 | 496 | "id": "Tc3MiEO2twLS"
|
497 | 497 | },
|
498 | 498 | "source": [
|
499 |
| - "Try out the model to see what it predicts before training." |
| 499 | + "Try out the model to check what it predicts before training." |
500 | 500 | ]
|
501 | 501 | },
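A helper like the following (a sketch; the notebook's own `show_predictions` may differ) runs the model and visualizes the result with the `display` and `create_mask` helpers sketched earlier:

```python
def show_predictions(dataset=None, num=1):
  # Show image, true mask, and predicted mask for `num` batches of `dataset`.
  if dataset:
    for image, mask in dataset.take(num):
      pred_mask = model.predict(image)
      display([image[0], mask[0], create_mask(pred_mask)])
  else:
    # `sample_image`/`sample_mask` are the example pair visualized earlier
    # (names assumed here for illustration).
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])

# Before any training, the predicted mask is essentially noise.
show_predictions()
```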
|
502 | 502 | {
|
|
622 | 622 | "id": "7BVXldSo-0mW"
|
623 | 623 | },
|
624 | 624 | "source": [
|
625 |
| - "Now make some predictions. In the interest of saving time, the number of epochs was kept small, but you may set this higher to achieve more accurate results." |
| 625 | + "Now, make some predictions. In the interest of saving time, the number of epochs was kept small, but you may set this higher to achieve more accurate results." |
626 | 626 | ]
|
627 | 627 | },
|
628 | 628 | {
|
|
651 | 651 | "id": "eqtFPqqu2kxP"
|
652 | 652 | },
|
653 | 653 | "source": [
|
654 |
| - "Semantic segmentation datasets can be highly imbalanced meaning that particular class pixels can be present more inside images than that of other classes. Since segmentation problems can be treated as per-pixel classification problems, you can deal with the imbalance problem by weighing the loss function to account for this. It's a simple and elegant way to deal with this problem. See the [imbalanced classes tutorial](https://www.tensorflow.org/tutorials/structured_data/imbalanced_data).\n", |
| 654 | + "Semantic segmentation datasets can be highly imbalanced meaning that particular class pixels can be present more inside images than that of other classes. Since segmentation problems can be treated as per-pixel classification problems, you can deal with the imbalance problem by weighing the loss function to account for this. It's a simple and elegant way to deal with this problem. Refer to the [Classification on imbalanced data](../structured_data/imbalanced_data.ipynb) tutorial to learn more.\n", |
655 | 655 | "\n",
|
656 | 656 | "To [avoid ambiguity](https://github.com/keras-team/keras/issues/3653#issuecomment-243939748), `Model.fit` does not support the `class_weight` argument for inputs with 3+ dimensions."
|
657 | 657 | ]
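To illustrate that limitation (with purely illustrative weights, and reusing the `model` and an assumed batched `train_batches` pipeline from earlier), passing `class_weight` to `Model.fit` for per-pixel labels fails:

```python
try:
  model.fit(
      train_batches,
      epochs=1,
      class_weight={0: 2.0, 1: 2.0, 2: 1.0})  # arbitrary illustrative weights
  assert False, "fit should have raised an error"
except Exception as e:
  print(f"Expected {type(e).__name__}: {e}")
```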
|
|
679 | 679 | "id": "brbhYODCsvbe"
|
680 | 680 | },
|
681 | 681 | "source": [
|
682 |
| - "So in this case you need to implement the weighting yourself. You'll do this using sample weights: In addition to `(data, label)` pairs, `Model.fit` also accepts `(data, label, sample_weight)` triples.\n", |
| 682 | + "So, in this case you need to implement the weighting yourself. You'll do this using sample weights: In addition to `(data, label)` pairs, `Model.fit` also accepts `(data, label, sample_weight)` triples.\n", |
683 | 683 | "\n",
|
684 | 684 | "`Model.fit` propagates the `sample_weight` to the losses and metrics, which also accept a `sample_weight` argument. The sample weight is multiplied by the sample's value before the reduction step. For example:"
|
685 | 685 | ]
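The example cell that follows in the notebook demonstrates this; a sketch of the same idea is below. The first snippet shows each per-sample loss being scaled by its weight, and the second shows one way (with illustrative class weights) to turn `(data, label)` pairs into `(data, label, sample_weight)` triples by looking up a weight for every pixel.

```python
# Two samples with the same label and prediction but different weights:
# the second loss value comes out 10x the first.
label = [0, 0]
prediction = [[-3.0, 0.0], [-3.0, 0.0]]
sample_weight = [1.0, 10.0]

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
print(loss_fn(label, prediction, sample_weight).numpy())  # -> [ 3.0486 30.486 ]

def add_sample_weights(image, label):
  # Relative weight for each class (illustrative values), normalized to sum to 1.
  class_weights = tf.constant([2.0, 2.0, 1.0])
  class_weights = class_weights / tf.reduce_sum(class_weights)
  # Index the weights by the per-pixel label to get a per-pixel weight image.
  sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))
  return image, label, sample_weights

# weighted_batches = train_batches.map(add_sample_weights)  # ready for Model.fit
```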
|
|
798 | 798 | },
|
799 | 799 | "source": [
|
800 | 800 | "## Next steps\n",
|
801 |
| - "Now that you have an understanding of what image segmentation is and how it works, you can try this tutorial out with different intermediate layer outputs, or even different pretrained model. You may also challenge yourself by trying out the [Carvana](https://www.kaggle.com/c/carvana-image-masking-challenge/overview) image masking challenge hosted on Kaggle.\n", |
| 801 | + "Now that you have an understanding of what image segmentation is and how it works, you can try this tutorial out with different intermediate layer outputs, or even different pretrained models. You may also challenge yourself by trying out the [Carvana](https://www.kaggle.com/c/carvana-image-masking-challenge/overview) image masking challenge hosted on Kaggle.\n", |
802 | 802 | "\n",
|
803 | 803 | "You may also want to see the [Tensorflow Object Detection API](https://github.com/tensorflow/models/blob/master/research/object_detection/README.md) for another model you can retrain on your own data. Pretrained models are available on [TensorFlow Hub](https://www.tensorflow.org/hub/tutorials/tf2_object_detection#optional)"
|
804 | 804 | ]
|
|