|
39 | 39 | "source": [
|
40 | 40 | "# Semantic Segmentation with Model Garden\n",
|
41 | 41 | "\n",
|
42 |
| - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", |
43 |
| - " \u003ctd\u003e\n", |
44 |
| - " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/semantic_segmentation\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", |
45 |
| - " \u003c/td\u003e\n", |
46 |
| - " \u003ctd\u003e\n", |
47 |
| - " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/semantic_segmentation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", |
48 |
| - " \u003c/td\u003e\n", |
49 |
| - " \u003ctd\u003e\n", |
50 |
| - " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/semantic_segmentation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n", |
51 |
| - " \u003c/td\u003e\n", |
52 |
| - " \u003ctd\u003e\n", |
53 |
| - " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/semantic_segmentation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", |
54 |
| - " \u003c/td\u003e\n", |
55 |
| - "\u003c/table\u003e" |
| 42 | + "<table class=\"tfo-notebook-buttons\" align=\"left\">\n", |
| 43 | + " <td>\n", |
| 44 | + " <a target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/semantic_segmentation\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n", |
| 45 | + " </td>\n", |
| 46 | + " <td>\n", |
| 47 | + " <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/semantic_segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n", |
| 48 | + " </td>\n", |
| 49 | + " <td>\n", |
| 50 | + " <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/semantic_segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n", |
| 51 | + " </td>\n", |
| 52 | + " <td>\n", |
| 53 | + " <a href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/semantic_segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n", |
| 54 | + " </td>\n", |
| 55 | + "</table>" |
56 | 56 | ]
|
57 | 57 | },
|
58 | 58 | {
|
|
95 | 95 | },
|
96 | 96 | "outputs": [],
|
97 | 97 | "source": [
|
98 |
| - "!pip install -U -q \"tensorflow\u003e=2.9.2\" \"tf-models-official\"" |
| 98 | + "!pip install -U -q \"tf-models-official\"" |
99 | 99 | ]
|
100 | 100 | },
|
101 | 101 | {
|
|
138 | 138 | "import orbit\n",
|
139 | 139 | "import tensorflow_models as tfm\n",
|
140 | 140 | "from official.vision.data import tfrecord_lib\n",
|
| 141 | + "from official.vision.utils import summary_manager\n", |
141 | 142 | "from official.vision.serving import export_saved_model_lib\n",
|
142 | 143 | "from official.vision.utils.object_detection import visualization_utils\n",
|
143 | 144 | "\n",
|
|
428 | 429 | "exp_config.task.validation_data.output_size = [HEIGHT, WIDTH]\n",
|
429 | 430 | "exp_config.task.validation_data.preserve_aspect_ratio = False\n",
|
430 | 431 | "exp_config.task.validation_data.groundtruth_padded_size = [HEIGHT, WIDTH]\n",
|
431 |
| - "exp_config.task.validation_data.seed = 21 # Reproducable Validation Data" |
| 432 | + "exp_config.task.validation_data.seed = 21 # Reproducable Validation Data\n", |
| 433 | + "exp_config.task.validation_data.resize_eval_groundtruth = True # To enable validation loss" |
432 | 434 | ]
|
433 | 435 | },
|
434 | 436 | {
|
|
540 | 542 | "source": [
|
541 | 543 | "## Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
|
542 | 544 | "\n",
|
543 |
| - "The `Task` object has all the methods necessary for building the dataset, building the model, and running training \u0026 evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`." |
| 545 | + "The `Task` object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`." |
544 | 546 | ]
|
545 | 547 | },
|
546 | 548 | {
|
|
597 | 599 | },
|
598 | 600 | "outputs": [],
|
599 | 601 | "source": [
|
600 |
| - "def display(display_list):\n", |
| 602 | + "def plot_masks(display_list):\n", |
601 | 603 | " plt.figure(figsize=(15, 15))\n",
|
602 | 604 | "\n",
|
603 | 605 | " title = ['Input Image', 'True Mask', 'Predicted Mask']\n",
|
|
636 | 638 | "num_examples = 3\n",
|
637 | 639 | "\n",
|
638 | 640 | "for images, masks in task.build_inputs(exp_config.task.train_data).take(num_examples):\n",
|
639 |
| - " display([images[0], masks['masks'][0]])" |
| 641 | + " plot_masks([images[0], masks['masks'][0]])" |
640 | 642 | ]
|
641 | 643 | },
|
642 | 644 | {
|
|
657 | 659 | },
|
658 | 660 | "outputs": [],
|
659 | 661 | "source": [
|
| 662 | + "\n", |
660 | 663 | "model, eval_logs = tfm.core.train_lib.run_experiment(\n",
|
661 | 664 | " distribution_strategy=distribution_strategy,\n",
|
662 | 665 | " task=task,\n",
|
663 | 666 | " mode='train_and_eval',\n",
|
664 | 667 | " params=exp_config,\n",
|
665 | 668 | " model_dir=model_dir,\n",
|
| 669 | + " eval_summary_manager=summary_manager.maybe_build_eval_summary_manager(\n", |
| 670 | + " params=exp_config, model_dir=model_dir),\n", |
666 | 671 | " run_post_eval=True)"
|
667 | 672 | ]
|
668 | 673 | },
|
|
764 | 769 | " image = tf.cast(image, dtype=tf.uint8)\n",
|
765 | 770 | " mask = tf.image.resize(record['segmentation_mask'], size=[HEIGHT, WIDTH])\n",
|
766 | 771 | " predicted_mask = model_fn(tf.expand_dims(record['image'], axis=0))\n",
|
767 |
| - " display([image, mask, create_mask(predicted_mask['logits'])])" |
| 772 | + " plot_masks([image, mask, create_mask(predicted_mask['logits'])])" |
768 | 773 | ]
|
769 | 774 | }
|
770 | 775 | ],
|
|
0 commit comments