|
69 | 69 | },
|
70 | 70 | "source": [
|
71 | 71 | "## Overview\n",
|
72 |
| - "In this tutoral, you will learn how to use DTensor with Keras.\n", |
| 72 | + "\n", |
| 73 | + "In this tutorial, you will learn how to use DTensors with Keras.\n", |
73 | 74 | "\n",
|
74 | 75 | "Through DTensor integration with Keras, you can reuse your existing Keras layers and models to build and train distributed machine learning models.\n",
|
75 | 76 | "\n",
|
76 | 77 | "You will train a multi-layer classification model with the MNIST data. Setting the layout for subclassing model, Sequential model, and functional model will be demonstrated.\n",
|
77 | 78 | "\n",
|
78 |
| - "This tutoral assumes that you have already read the [DTensor programing guide](/guide/dtensor_overview), and are familiar with basic DTensor concepts like `Mesh` and `Layout`.\n", |
| 79 | + "This tutorial assumes that you have already read the [DTensor programming guide](/guide/dtensor_overview), and are familiar with basic DTensor concepts like `Mesh` and `Layout`.\n", |
79 | 80 | "\n",
|
80 |
| - "This tutoral is based on https://www.tensorflow.org/datasets/keras_example." |
| 81 | + "This tutorial is based on [Training a neural network on MNIST with Keras](https://www.tensorflow.org/datasets/keras_example)." |
81 | 82 | ]
|
82 | 83 | },
|
83 | 84 | {
|
|
88 | 89 | "source": [
|
89 | 90 | "## Setup\n",
|
90 | 91 | "\n",
|
91 |
| - "DTensor is part of TensorFlow 2.9.0 release." |
| 92 | + "DTensor (`tf.experimental.dtensor`) has been part of TensorFlow since the 2.9.0 release.\n", |
| 93 | + "\n", |
| 94 | + "First, install or upgrade TensorFlow and TensorFlow Datasets:" |
92 | 95 | ]
|
93 | 96 | },
|
94 | 97 | {
|
|
99 | 102 | },
|
100 | 103 | "outputs": [],
|
101 | 104 | "source": [
|
102 |
| - "!pip install --quiet --upgrade --pre tensorflow tensorflow-datasets" |
| 105 | + "!pip install --quiet --upgrade tensorflow tensorflow-datasets" |
103 | 106 | ]
|
104 | 107 | },
|
105 | 108 | {
|
|
108 | 111 | "id": "VttBMZngDx8x"
|
109 | 112 | },
|
110 | 113 | "source": [
|
111 |
| - "Next, import `tensorflow` and `tensorflow.experimental.dtensor`, and configure TensorFlow to use 8 virtual CPUs.\n", |
| 114 | + "Next, import `tensorflow` and `dtensor`, and configure TensorFlow to use 8 virtual CPUs.\n", |
112 | 115 | "\n",
|
113 |
| - "Even though this example uses CPUs, DTensor works the same way on CPU, GPU or TPU devices." |
| 116 | + "Even though this example uses virtual CPUs, DTensor works the same way on CPU, GPU or TPU devices." |
114 | 117 | ]
|
115 | 118 | },
|
116 | 119 | {
|
|
176 | 179 | "source": [
|
177 | 180 | "## Creating a Data Parallel Mesh\n",
|
178 | 181 | "\n",
|
179 |
| - "This tutorial demonstrates Data Parallel training. Adapting to Model Parallel training and Spatial Parallel training can be as simple as switching to a different set of `Layout` objects. Refer to [DTensor in-depth ML Tutorial](https://www.tensorflow.org/tutorials/distribute/dtensor_ml_tutorial) for more information on distributed training beyond Data Parallel.\n", |
| 182 | + "This tutorial demonstrates Data Parallel training. Adapting to Model Parallel training and Spatial Parallel training can be as simple as switching to a different set of `Layout` objects. Refer to the [Distributed training with DTensors](dtensor_ml_tutorial.ipynb) tutorial for more information on distributed training beyond Data Parallel.\n", |
180 | 183 | "\n",
|
181 |
| - "Data Parallel training is a commonly used parallel training scheme, also used by for example `tf.distribute.MirroredStrategy`.\n", |
| 184 | + "Data Parallel training is a commonly used parallel training scheme, also used by, for example, `tf.distribute.MirroredStrategy`.\n", |
182 | 185 | "\n",
|
183 |
| - "With DTensor, a Data Parallel training loop uses a `Mesh` that consists of a single 'batch' dimension, where each device runs a replica of the model that receives a shard from the global batch.\n" |
| 186 | + "With DTensor, a Data Parallel training loop uses a `Mesh` that consists of a single 'batch' dimension, where each device runs a replica of the model that receives a shard from the global batch." |
184 | 187 | ]
|
185 | 188 | },
|
186 | 189 | {
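The data-parallel `Mesh` described in this hunk can be sketched as follows. This is a minimal sketch, assuming the 8-virtual-CPU configuration from the Setup section; the device list and the `'batch'` dimension name follow the tutorial's conventions:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Split the single physical CPU into 8 virtual CPU devices
# (assumes the 8-device configuration from the Setup section).
phys_cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    phys_cpu, [tf.config.LogicalDeviceConfiguration()] * 8)

# A 1-D mesh with a single 'batch' dimension: each of the 8 devices
# runs one model replica and receives one shard of the global batch.
mesh = dtensor.create_mesh([("batch", 8)],
                           devices=[f"CPU:{i}" for i in range(8)])
```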
|
|
248 | 251 | "\n",
|
249 | 252 | "In order to configure the layout information for your layers' weights, Keras has exposed an extra parameter in the layer constructor for most of the built-in layers.\n",
|
250 | 253 | "\n",
|
251 |
| - "The following example builds a small image classification model with fully replicated weight layout. You can specify layout information `kernel` and `bias` in `tf.keras.layers.Dense` via argument `kernel_layout` and `bias_layout`. Most of the built-in keras layers are ready for explicitly specifying the `Layout` for the layer weights." |
| 254 | + "The following example builds a small image classification model with a fully replicated weight layout. You can specify the layout information for `kernel` and `bias` in `tf.keras.layers.Dense` via the `kernel_layout` and `bias_layout` arguments. Most of the built-in Keras layers support explicitly specifying the `Layout` for the layer weights." |
252 | 255 | ]
|
253 | 256 | },
|
254 | 257 | {
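As a sketch of what those explicit per-layer layouts look like, assuming the 8-CPU `'batch'` mesh from earlier (the layer sizes are illustrative; enabling the TF random generator is needed so initializers work with DTensor variables):

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices("CPU")[0],
    [tf.config.LogicalDeviceConfiguration()] * 8)
mesh = dtensor.create_mesh([("batch", 8)],
                           devices=[f"CPU:{i}" for i in range(8)])

# DTensor requires the TF stateless random generator for weight init.
tf.keras.backend.experimental.enable_tf_random_generator()
tf.keras.utils.set_random_seed(1337)

# Fully replicated layouts: rank 2 for kernels, rank 1 for biases.
unsharded_layout_2d = dtensor.Layout.replicated(mesh, 2)
unsharded_layout_1d = dtensor.Layout.replicated(mesh, 1)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation="relu",
                          kernel_layout=unsharded_layout_2d,
                          bias_layout=unsharded_layout_1d),
    tf.keras.layers.Dense(10,
                          kernel_layout=unsharded_layout_2d,
                          bias_layout=unsharded_layout_1d),
])
```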
|
|
315 | 318 | "source": [
|
316 | 319 | "## Load a dataset and build input pipeline\n",
|
317 | 320 | "\n",
|
318 |
| - "Load a MNIST dataset and configure some pre-processing input pipeline for it. The dataset itself is not associated with any DTensor layout information. There are plans to improve DTensor Keras integration with `tf.data` in future TensorFlow releases.\n" |
| 321 | + "Load the MNIST dataset and configure an input pipeline with some pre-processing for it. The dataset itself is not associated with any DTensor layout information." |
319 | 322 | ]
|
320 | 323 | },
|
321 | 324 | {
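The pre-processing part of such a pipeline can be sketched as follows. A synthetic in-memory batch stands in for MNIST so the sketch stays self-contained; the normalization step mirrors the usual MNIST example:

```python
import tensorflow as tf

def normalize_img(image, label):
    """Scale uint8 pixel values to float32 in [0, 1]."""
    return tf.cast(image, tf.float32) / 255.0, label

# Synthetic stand-in for the MNIST training split (28x28 grayscale).
images = tf.zeros([32, 28, 28, 1], dtype=tf.uint8)
labels = tf.zeros([32], dtype=tf.int64)

ds_train = (tf.data.Dataset.from_tensor_slices((images, labels))
            .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
            .shuffle(32)
            .batch(8)
            .prefetch(tf.data.AUTOTUNE))
```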
|
|
389 | 392 | "source": [
|
390 | 393 | "## Define the training logic for the model\n",
|
391 | 394 | "\n",
|
392 |
| - "Next define the training and evalution logic for the model. \n", |
| 395 | + "Next, define the training and evaluation logic for the model. \n", |
393 | 396 | "\n",
|
394 |
| - "As of TensorFlow 2.9, you have to write a custom-training-loop for a DTensor enabled Keras model. This is to pack the input data with proper layout information, which is not integrated with the standard `tf.keras.Model.fit()` or `tf.keras.Model.eval()` functions from Keras. you will get more `tf.data` support in the upcoming release. " |
| 397 | + "As of TensorFlow 2.9, you have to write a custom training loop for a DTensor-enabled Keras model. This is needed to pack the input data with the proper layout information, which is not integrated with the standard `tf.keras.Model.fit()` or `tf.keras.Model.evaluate()` functions from Keras. You will get more `tf.data` support in an upcoming release. " |
395 | 398 | ]
|
396 | 399 | },
|
397 | 400 | {
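Packing a global batch with the proper layout information, as described above, might look like this. This is a sketch: `repack_batch` is an illustrative helper name, not a library function, and the mesh setup repeats the earlier data-parallel configuration:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices("CPU")[0],
    [tf.config.LogicalDeviceConfiguration()] * 8)
mesh = dtensor.create_mesh([("batch", 8)],
                           devices=[f"CPU:{i}" for i in range(8)])

def repack_batch(x, mesh):
    """Shard a global batch along axis 0 over the mesh's 'batch' dimension."""
    layout = dtensor.Layout(
        ["batch"] + [dtensor.UNSHARDED] * (x.shape.rank - 1), mesh)
    # One shard per local device, packed into a single DTensor.
    shards = tf.split(x, mesh.num_local_devices(), axis=0)
    return dtensor.pack(shards, layout)

batch = tf.zeros([16, 784])
d_batch = repack_batch(batch, mesh)
```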
|
|
467 | 470 | "id": "9Eb-qIJGrxB9"
|
468 | 471 | },
|
469 | 472 | "source": [
|
470 |
| - "## Metrics and Optimizers\n", |
| 473 | + "## Metrics and optimizers\n", |
471 | 474 | "\n",
|
472 | 475 | "When using the DTensor API with Keras `Metric` and `Optimizer`, you will need to provide the extra mesh information so that any internal state variables and tensors can work with the variables in the model.\n",
|
473 | 476 | "\n",
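For TF 2.9, providing that mesh information looks roughly like the following sketch (assumes the mesh built earlier; the learning rate and metric choice are illustrative):

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices("CPU")[0],
    [tf.config.LogicalDeviceConfiguration()] * 8)
mesh = dtensor.create_mesh([("batch", 8)],
                           devices=[f"CPU:{i}" for i in range(8)])

# Both the optimizer and the metric receive the mesh so their internal
# state variables are created with compatible DTensor layouts.
optimizer = tf.keras.dtensor.experimental.optimizers.Adam(0.01, mesh=mesh)
metrics = {"accuracy": tf.keras.metrics.SparseCategoricalAccuracy(mesh=mesh)}
```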
|
|
497 | 500 | "source": [
|
498 | 501 | "## Train the model\n",
|
499 | 502 | "\n",
|
500 |
| - "The following example shards the data from input pipeline on the batch dimension, and train with the model, which has fully replicated weights. \n", |
| 503 | + "The following example demonstrates how to shard the data from the input pipeline along the batch dimension, and train the model, which has fully replicated weights. \n", |
501 | 504 | "\n",
|
502 |
| - "With 3 epochs, the model should achieve about 97% of accuracy." |
| 505 | + "After 3 epochs, the model should achieve about 97% accuracy:" |
503 | 506 | ]
|
504 | 507 | },
|
505 | 508 | {
|
|
561 | 564 | "\n",
|
562 | 565 | "Often you have models that work well for your use case. Specifying `Layout` information for each individual layer within the model would be a large amount of work requiring many edits.\n",
|
563 | 566 | "\n",
|
564 |
| - "To help you easily convert your existing Keras model to work with DTensor API you can use the new `dtensor.LayoutMap` API that allow you to specify the `Layout` from a global point of view.\n", |
| 567 | + "To help you easily convert your existing Keras model to work with the DTensor API, you can use the new `tf.keras.dtensor.experimental.LayoutMap` API, which allows you to specify the `Layout` from a global point of view.\n", |
565 | 568 | "\n",
|
566 | 569 | "First, you need to create a `LayoutMap` instance, which is a dictionary-like object that contains all of the `Layout`s you would like to specify for your model weights.\n",
|
567 | 570 | "\n",
|
568 | 571 | "`LayoutMap` needs a `Mesh` instance at init, which can be used to provide a default replicated `Layout` for any weights that don't have a `Layout` configured. If you would like all your model weights to be fully replicated, you can provide an empty `LayoutMap`, and the default mesh will be used to create a replicated `Layout`.\n",
|
569 | 572 | "\n",
|
570 |
| - "`LayoutMap` uses a string as key and a `Layout` as value. There is a behavior difference between a normal Python dict and this class. The string key will be treated as a regex when retrieving the value" |
| 573 | + "`LayoutMap` uses a string as the key and a `Layout` as the value. Unlike a normal Python dict, the string key is treated as a regex when retrieving the value." |
571 | 574 | ]
|
572 | 575 | },
|
573 | 576 | {
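A sketch of the regex-keyed mapping described above, assuming the earlier data-parallel mesh (the `feature.*` patterns are illustrative attribute paths):

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices("CPU")[0],
    [tf.config.LogicalDeviceConfiguration()] * 8)
mesh = dtensor.create_mesh([("batch", 8)],
                           devices=[f"CPU:{i}" for i in range(8)])

# Keys are regex patterns; values are Layouts. The mesh passed at init
# supplies default replicated layouts for any unmatched weights.
layout_map = tf.keras.dtensor.experimental.LayoutMap(mesh=mesh)
layout_map["feature.*kernel"] = dtensor.Layout.replicated(mesh, 2)
layout_map["feature.*bias"] = dtensor.Layout.replicated(mesh, 1)
```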
|
|
616 | 619 | "* `model.feature_2.kernel`\n",
|
617 | 620 | "* `model.feature_2.bias`\n",
|
618 | 621 | "\n",
|
619 |
| - "Note: For Subclassed Models, the attribute name, rather than the `.name` attribute of layer are used as the key to retrieve the Layout from the mapping. This is consistent with the convention followed by `tf.Module` checkpointing. For complex models with more than a few layers, you can [manually inspect checkpoints](https://www.tensorflow.org/guide/checkpoint#manually_inspecting_checkpoints) to see the attribute mappings. \n", |
| 622 | + "Note: For subclassed Models, the attribute name, rather than the `.name` attribute of the layer, is used as the key to retrieve the Layout from the mapping. This is consistent with the convention followed by `tf.Module` checkpointing. For complex models with more than a few layers, you can [manually inspect checkpoints](https://www.tensorflow.org/guide/checkpoint#manually_inspecting_checkpoints) to view the attribute mappings. \n", |
620 | 623 | "\n",
|
621 |
| - "Now define the following `LayoutMap` and apply it to the model." |
| 624 | + "Now define the following `LayoutMap` and apply it to the model:" |
622 | 625 | ]
|
623 | 626 | },
|
624 | 627 | {
|
|
644 | 647 | "id": "M32HcSp_PyWs"
|
645 | 648 | },
|
646 | 649 | "source": [
|
647 |
| - "The model weights are created on the first call, so call the model with a DTensor input and confirm the weights have the expected layouts." |
| 650 | + "The model weights are created on the first call, so call the model with a DTensor input and confirm the weights have the expected layouts:" |
648 | 651 | ]
|
649 | 652 | },
|
650 | 653 | {
|
|
686 | 689 | "id": "6zzvTqAR2Teu"
|
687 | 690 | },
|
688 | 691 | "source": [
|
689 |
| - "For keras functional and sequential models, you can use `LayoutMap` as well.\n", |
| 692 | + "For Keras Functional and Sequential models, you can use `tf.keras.dtensor.experimental.LayoutMap` as well.\n", |
690 | 693 | "\n",
|
691 |
| - "Note: For functional and sequential models, the mappings are slightly different. The layers in the model don't have a public attribute attached to the model (though you can access them via `model.layers` as a list). Use the string name as the key in this case. The string name is guaranteed to be unique within a model." |
| 694 | + "Note: For Functional and Sequential models, the mappings are slightly different. The layers in the model don't have a public attribute attached to the model (though you can access them via `Model.layers` as a list). Use the string name as the key in this case. The string name is guaranteed to be unique within a model." |
692 | 695 | ]
|
693 | 696 | },
|
694 | 697 | {
|
|
745 | 748 | "metadata": {
|
746 | 749 | "colab": {
|
747 | 750 | "name": "dtensor_keras_tutorial.ipynb",
|
748 |
| - "toc_visible": true |
| 751 | + "toc_visible": true |
749 | 752 | },
|
750 | 753 | "kernelspec": {
|
751 | 754 | "display_name": "Python 3",
|
|