---
title: Distributed training with Keras 3
author: '[Qianli Zhu](https://github.com/qlzh727)'
date-created: 2023/11/07
last-modified: 2023/11/07
description: Complete guide to the distribution API for multi-backend Keras.
accelerator: GPU
output: rmarkdown::html_vignette
knit: ({source(here::here("tools/knit.R")); knit_vignette})
tether: ~/github/keras-team/keras-io/guides/distribution.py
---

## Introduction

The Keras distribution API is a new interface designed to facilitate
distributed deep learning across a variety of backends like JAX, TensorFlow and
PyTorch. This powerful API introduces a suite of tools enabling data and model
parallelism, allowing for efficient scaling of deep learning models on multiple
accelerators and hosts. Whether leveraging the power of GPUs or TPUs, the API
provides a streamlined approach to initializing distributed environments,
defining device meshes, and orchestrating the layout of tensors across
computational resources. Through classes like `DataParallel` and
`ModelParallel`, it abstracts the complexity involved in parallel computation,
making it easier for developers to accelerate their machine learning
workflows.

## How it works

The Keras distribution API provides a global programming model that allows
developers to compose applications that operate on tensors in a global context
(as if working with a single device) while automatically managing distribution
across many devices. The API leverages the underlying framework (e.g. JAX) to
distribute the program and tensors according to the sharding directives through
a procedure called single program, multiple data (SPMD) expansion.

By decoupling the application from sharding directives, the API enables running
the same application on a single device, multiple devices, or even multiple
clients, while preserving its global semantics.

## Setup

```python
import os

# The distribution API is only implemented for the JAX backend for now.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from keras import layers
import jax
import numpy as np
from tensorflow import data as tf_data  # For dataset input.
```
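
As a quick sanity check (a minimal sketch, assuming your Keras version exposes
`keras.distribution.list_devices()`), you can enumerate the devices Keras can
see before defining any mesh:

```python
# List the devices visible to the current backend. The same model code runs
# unchanged whether this returns 1 or 8 devices; only the distribution object
# you configure later changes.
print(keras.distribution.list_devices())
```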

## `DeviceMesh` and `TensorLayout`

The `keras.distribution.DeviceMesh` class in the Keras distribution API represents a cluster of
computational devices configured for distributed computation. It aligns with
similar concepts in [`jax.sharding.Mesh`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh) and
[`tf.dtensor.Mesh`](https://www.tensorflow.org/api_docs/python/tf/experimental/dtensor/Mesh),
where it's used to map the physical devices to a logical mesh structure.

The `TensorLayout` class then specifies how tensors are distributed across the
`DeviceMesh`, detailing the sharding of tensors along specified axes that
correspond to the names of the axes in the `DeviceMesh`.

You can find more detailed concept explainers in the
[TensorFlow DTensor guide](https://www.tensorflow.org/guide/dtensor_overview#dtensors_model_of_distributed_tensors).

```python
# Retrieve the locally available GPU devices.
devices = jax.devices("gpu")  # Assume it has 8 local GPUs.

# Define a 2x4 device mesh with data and model parallel axes.
mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)

# A 2D layout, which describes how a tensor is distributed across the
# mesh. The layout can be visualized as a 2D grid with "model" as rows and
# "data" as columns, and it is a [4, 2] grid when it is mapped to the physical
# devices on the mesh.
layout_2d = keras.distribution.TensorLayout(axes=("model", "data"), device_mesh=mesh)

# A 4D layout that could be used for data parallelism of an image input.
replicated_layout_4d = keras.distribution.TensorLayout(
    axes=("data", None, None, None), device_mesh=mesh
)
```

## Distribution

The `Distribution` class in Keras serves as a foundational abstract class designed
for developing custom distribution strategies. It encapsulates the core logic
needed to distribute a model's variables, input data, and intermediate
computations across a device mesh. As an end user, you won't interact with this
class directly, but rather with its subclasses such as `DataParallel` or
`ModelParallel`.
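
For illustration, here is a minimal sketch of how a `Distribution` subclass is
activated. It assumes the getter `keras.distribution.distribution()` is
available in your Keras version for inspecting the currently active
distribution:

```python
# You never instantiate `Distribution` directly; instead you pick a subclass
# and register it globally.
data_parallel = keras.distribution.DataParallel()  # Uses all local devices.
keras.distribution.set_distribution(data_parallel)

# The currently active distribution can be retrieved for inspection.
print(keras.distribution.distribution())
```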

## DataParallel

The `DataParallel` class in the Keras distribution API is designed for the
data parallelism strategy in distributed training, where the model weights are
replicated across all devices in the `DeviceMesh`, and each device processes a
portion of the input data.

Here is a sample usage of this class.

```python
# Create DataParallel with a list of devices.
# As a shortcut, the devices can be skipped,
# and Keras will detect all locally available devices.
# E.g. data_parallel = DataParallel()
data_parallel = keras.distribution.DataParallel(devices=devices)

# Or you can choose to create DataParallel with a 1D `DeviceMesh`.
mesh_1d = keras.distribution.DeviceMesh(
    shape=(8,), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)

inputs = np.random.normal(size=(128, 28, 28, 1))
labels = np.random.normal(size=(128, 10))
dataset = tf_data.Dataset.from_tensor_slices((inputs, labels)).batch(16)

# Set the global distribution.
keras.distribution.set_distribution(data_parallel)

# Note that all the model weights from here on are replicated to
# all the devices of the `DeviceMesh`. This includes the RNG
# state, optimizer states, metrics, etc. The dataset fed into `model.fit` or
# `model.evaluate` will be split evenly on the batch dimension, and sent to
# all the devices. You don't have to do any manual aggregation of losses,
# since all the computation happens in a global context.
inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax")(y)
model = keras.Model(inputs=inputs, outputs=y)

model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
```

## `ModelParallel` and `LayoutMap`

`ModelParallel` is mostly useful when the model weights are too large to fit
on a single accelerator. This setting allows you to split your model weights or
activation tensors across all the devices on the `DeviceMesh`, enabling
horizontal scaling for large models.

Unlike `DataParallel`, where all weights are fully replicated, the weights
layout under `ModelParallel` usually needs some customization for best
performance. We introduce `LayoutMap` to let you specify the `TensorLayout`
for any weights and intermediate tensors from a global perspective.

`LayoutMap` is a dict-like object that maps a string to `TensorLayout`
instances. It behaves differently from a normal Python dict in that the string
key is treated as a regex when retrieving the value. The class allows you to
define the naming schema of `TensorLayout` and then retrieve the corresponding
`TensorLayout` instance. Typically, the key used to query is the
`variable.path` attribute, which is the identifier of the variable.
As a shortcut, a tuple or list of axis names is also allowed when inserting a
value, and it will be converted to `TensorLayout`.

The `LayoutMap` can also optionally contain a `DeviceMesh` to populate the
`TensorLayout.device_mesh` if it is not set. When retrieving a layout with a
key, if there isn't an exact match, all existing keys in the layout map will
be treated as regexes and matched against the input key again. If there are
multiple matches, a `ValueError` is raised. If no matches are found, `None` is
returned.
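
If you are unsure which keys to use, you can simply print the `path` attribute
of your model's variables. The snippet below is a small sketch of that (the toy
layer name `d1` and the printed paths are illustrative and depend on how you
name your layers):

```python
# Build a throwaway model just to inspect variable paths.
tmp_inputs = layers.Input(shape=(4,))
tmp_outputs = layers.Dense(2, name="d1")(tmp_inputs)
tmp_model = keras.Model(tmp_inputs, tmp_outputs)
for variable in tmp_model.weights:
    # Prints something like "d1/kernel" and "d1/bias"; these strings are what
    # you would use (or regex-match) as `LayoutMap` keys.
    print(variable.path)
```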

```python
mesh_2d = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh_2d)
# The rule below means that any weights that match `d1/kernel` will be sharded
# along the "model" dimension (4 devices); the same applies to `d1/bias`.
# All other weights will be fully replicated.
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)

# You can also set the layout for the layer output like this:
layout_map["d2/output"] = ("data", None)

model_parallel = keras.distribution.ModelParallel(
    mesh_2d, layout_map, batch_dim_name="data"
)

keras.distribution.set_distribution(model_parallel)

inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu", name="d1")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax", name="d2")(y)
model = keras.Model(inputs=inputs, outputs=y)

# The data will be sharded across the "data" dimension of the mesh, which
# has 2 devices.
model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
```

It is also easy to change the mesh structure to tune the computation between
more data parallelism and more model parallelism. You can do this by adjusting
the shape of the mesh; no changes to any other code are needed.

```python
full_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(8, 1), axis_names=["data", "model"], devices=devices
)
more_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(4, 2), axis_names=["data", "model"], devices=devices
)
more_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
full_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8), axis_names=["data", "model"], devices=devices
)
```
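
For example, to lean more towards data parallelism, you would rebuild the layout
map on one of the alternative meshes and pass that mesh to `ModelParallel`. The
sketch below reuses the same `d1` sharding rules as above; only the mesh shape
differs:

```python
# Rebuild the layout rules on the (4, 2) mesh; the rules themselves are
# unchanged, only the mesh shape differs.
layout_map = keras.distribution.LayoutMap(more_data_parallel_mesh)
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)

model_parallel = keras.distribution.ModelParallel(
    more_data_parallel_mesh, layout_map, batch_dim_name="data"
)
keras.distribution.set_distribution(model_parallel)
```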

### Further reading

1. [JAX Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
2. [JAX sharding module](https://jax.readthedocs.io/en/latest/jax.sharding.html)
3. [TensorFlow Distributed training with DTensors](https://www.tensorflow.org/tutorials/distribute/dtensor_ml_tutorial)
4. [TensorFlow DTensor concepts](https://www.tensorflow.org/guide/dtensor_overview)
5. [Using DTensors with tf.keras](https://www.tensorflow.org/tutorials/distribute/dtensor_keras_tutorial)