Skip to content

Commit 94fd565

Browse files
Merge pull request #2300 from kinarr:image-segmentation-tutorial-keras-3-update
PiperOrigin-RevId: 624158569
2 parents ec3a1b3 + 2f4e5fa commit 94fd565

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

site/en/tutorials/images/segmentation.ipynb

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,10 @@
9797
},
9898
"outputs": [],
9999
"source": [
100-
"!pip install git+https://github.com/tensorflow/examples.git"
100+
"!pip install git+https://github.com/tensorflow/examples.git\n",
101+
"!pip install -U keras\n",
102+
"!pip install -q tensorflow_datasets\n",
103+
"!pip install -q -U tensorflow-text tensorflow"
101104
]
102105
},
103106
{
@@ -108,8 +111,9 @@
108111
},
109112
"outputs": [],
110113
"source": [
111-
"import tensorflow as tf\n",
114+
"import numpy as np\n",
112115
"\n",
116+
"import tensorflow as tf\n",
113117
"import tensorflow_datasets as tfds"
114118
]
115119
},
@@ -252,7 +256,7 @@
252256
" # both use the same seed, so they'll make the same random changes.\n",
253257
" self.augment_inputs = tf.keras.layers.RandomFlip(mode=\"horizontal\", seed=seed)\n",
254258
" self.augment_labels = tf.keras.layers.RandomFlip(mode=\"horizontal\", seed=seed)\n",
255-
" \n",
259+
"\n",
256260
" def call(self, inputs, labels):\n",
257261
" inputs = self.augment_inputs(inputs)\n",
258262
" labels = self.augment_labels(labels)\n",
@@ -450,7 +454,7 @@
450454
"source": [
451455
"## Train the model\n",
452456
"\n",
453-
"Now, all that is left to do is to compile and train the model. \n",
457+
"Now, all that is left to do is to compile and train the model.\n",
454458
"\n",
455459
"Since this is a multiclass classification problem, use the `tf.keras.losses.SparseCategoricalCrossentropy` loss function with the `from_logits` argument set to `True`, since the labels are scalar integers instead of vectors of scores for each pixel of every class.\n",
456460
"\n",
@@ -490,7 +494,7 @@
490494
},
491495
"outputs": [],
492496
"source": [
493-
"tf.keras.utils.plot_model(model, show_shapes=True)"
497+
"tf.keras.utils.plot_model(model, show_shapes=True, expand_nested=True, dpi=64)"
494498
]
495499
},
496500
{
@@ -695,12 +699,14 @@
695699
},
696700
"outputs": [],
697701
"source": [
698-
"label = [0,0]\n",
699-
"prediction = [[-3., 0], [-3, 0]] \n",
700-
"sample_weight = [1, 10] \n",
702+
"label = np.array([0,0])\n",
703+
"prediction = np.array([[-3., 0], [-3, 0]])\n",
704+
"sample_weight = [1, 10]\n",
701705
"\n",
702-
"loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,\n",
703-
" reduction=tf.keras.losses.Reduction.NONE)\n",
706+
"loss = tf.keras.losses.SparseCategoricalCrossentropy(\n",
707+
" from_logits=True,\n",
708+
" reduction=tf.keras.losses.Reduction.NONE\n",
709+
")\n",
704710
"loss(label, prediction, sample_weight).numpy()"
705711
]
706712
},
@@ -729,7 +735,7 @@
729735
" class_weights = tf.constant([2.0, 2.0, 1.0])\n",
730736
" class_weights = class_weights/tf.reduce_sum(class_weights)\n",
731737
"\n",
732-
" # Create an image of `sample_weights` by using the label at each pixel as an \n",
738+
" # Create an image of `sample_weights` by using the label at each pixel as an\n",
733739
" # index into the `class weights` .\n",
734740
" sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))\n",
735741
"\n",
@@ -811,7 +817,6 @@
811817
"metadata": {
812818
"accelerator": "GPU",
813819
"colab": {
814-
"collapsed_sections": [],
815820
"name": "segmentation.ipynb",
816821
"toc_visible": true
817822
},

0 commit comments

Comments
 (0)