Commit ad3f99f

checkin new example tethers
1 parent 3592976 commit ad3f99f

7 files changed: +1182 -43 lines changed

Lines changed: 237 additions & 0 deletions
---
title: Distributed training with Keras 3
author: '[Qianli Zhu](https://github.com/qlzh727)'
date-created: 2023/11/07
last-modified: 2023/11/07
description: Complete guide to the distribution API for multi-backend Keras.
accelerator: GPU
output: rmarkdown::html_vignette
knit: ({source(here::here("tools/knit.R")); knit_vignette})
tether: ~/github/keras-team/keras-io/guides/distribution.py
---

## Introduction

The Keras distribution API is a new interface designed to facilitate
distributed deep learning across a variety of backends like JAX, TensorFlow and
PyTorch. This powerful API introduces a suite of tools enabling data and model
parallelism, allowing for efficient scaling of deep learning models on multiple
accelerators and hosts. Whether leveraging the power of GPUs or TPUs, the API
provides a streamlined approach to initializing distributed environments,
defining device meshes, and orchestrating the layout of tensors across
computational resources. Through classes like `DataParallel` and
`ModelParallel`, it abstracts the complexity involved in parallel computation,
making it easier for developers to accelerate their machine learning
workflows.

## How it works

The Keras distribution API provides a global programming model that allows
developers to compose applications that operate on tensors in a global context
(as if working with a single device) while automatically managing distribution
across many devices. The API leverages the underlying framework (e.g. JAX) to
distribute the program and tensors according to the sharding directives through
a procedure called single program, multiple data (SPMD) expansion.

By decoupling the application from its sharding directives, the API enables
running the same application on a single device, multiple devices, or even
multiple clients, while preserving its global semantics.

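As a hedged illustration of this global programming model (this snippet is not
part of the original guide), note that the model-building code itself never
mentions devices or shards; the distribution is configured once, globally, and
the same function works with or without it:

```python
import os

os.environ["KERAS_BACKEND"] = "jax"  # the API currently targets the JAX backend

import keras


def build_model():
    # Nothing distribution-specific here: plain global-context Keras code.
    inputs = keras.layers.Input(shape=(4,))
    outputs = keras.layers.Dense(2)(inputs)
    return keras.Model(inputs, outputs)


# With no distribution set, this runs on a single device.
model = build_model()

# After a global distribution is set (see the sections below), the exact same
# call returns a sharded/replicated model, with no changes to build_model():
# keras.distribution.set_distribution(keras.distribution.DataParallel())
# model = build_model()
```
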
## Setup

```python
import os

# The distribution API is only implemented for the JAX backend for now.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from keras import layers
import jax
import numpy as np
from tensorflow import data as tf_data  # For dataset input.
```

## `DeviceMesh` and `TensorLayout`

The `keras.distribution.DeviceMesh` class in the Keras distribution API represents a cluster of
computational devices configured for distributed computation. It aligns with
similar concepts in [`jax.sharding.Mesh`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh) and
[`tf.dtensor.Mesh`](https://www.tensorflow.org/api_docs/python/tf/experimental/dtensor/Mesh),
where it's used to map the physical devices to a logical mesh structure.

The `TensorLayout` class then specifies how tensors are distributed across the
`DeviceMesh`, detailing the sharding of tensors along specified axes that
correspond to the names of the axes in the `DeviceMesh`.

You can find more detailed concept explainers in the
[TensorFlow DTensor guide](https://www.tensorflow.org/guide/dtensor_overview#dtensors_model_of_distributed_tensors).

```python
# Retrieve the locally available GPU devices.
devices = jax.devices("gpu")  # Assume it has 8 local GPUs.

# Define a 2x4 device mesh with data and model parallel axes.
mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)

# A 2D layout, which describes how a tensor is distributed across the
# mesh. The layout can be visualized as a 2D grid with "model" as rows and
# "data" as columns, and it is a [4, 2] grid when it is mapped to the physical
# devices on the mesh.
layout_2d = keras.distribution.TensorLayout(axes=("model", "data"), device_mesh=mesh)

# A 4D layout which could be used for data parallelism of an image input.
replicated_layout_4d = keras.distribution.TensorLayout(
    axes=("data", None, None, None), device_mesh=mesh
)
```

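If you prefer not to query JAX directly, a backend-agnostic sketch (not from
the original snippet) can use `keras.distribution.list_devices()` instead; the
mesh construction itself is unchanged:

```python
# Backend-agnostic device discovery; assumes at least 8 accelerators are visible.
devices = keras.distribution.list_devices()  # e.g. ["gpu:0", ..., "gpu:7"]

mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices[:8]
)
```
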
## Distribution

The `Distribution` class in Keras serves as a foundational abstract class designed
for developing custom distribution strategies. It encapsulates the core logic
needed to distribute a model's variables, input data, and intermediate
computations across a device mesh. As an end user, you won't interact with this
class directly; instead, you will use its subclasses, such as `DataParallel` or
`ModelParallel`.

## DataParallel

The `DataParallel` class in the Keras distribution API is designed for the
data parallelism strategy in distributed training, where the model weights are
replicated across all devices in the `DeviceMesh`, and each device processes a
portion of the input data.

Here is a sample usage of this class.

```python
# Create DataParallel with a list of devices.
# As a shortcut, the devices can be skipped,
# and Keras will detect all the locally available devices.
# E.g. data_parallel = DataParallel()
data_parallel = keras.distribution.DataParallel(devices=devices)

# Or you can choose to create DataParallel with a 1D `DeviceMesh`.
mesh_1d = keras.distribution.DeviceMesh(
    shape=(8,), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)

inputs = np.random.normal(size=(128, 28, 28, 1))
labels = np.random.normal(size=(128, 10))
dataset = tf_data.Dataset.from_tensor_slices((inputs, labels)).batch(16)

# Set the global distribution.
keras.distribution.set_distribution(data_parallel)

# Note that all the model weights from here on are replicated to
# all the devices of the `DeviceMesh`. This includes the RNG
# state, optimizer states, metrics, etc. The dataset fed into `model.fit` or
# `model.evaluate` will be split evenly on the batch dimension, and sent to
# all the devices. You don't have to do any manual aggregation of losses,
# since all the computation happens in a global context.
inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax")(y)
model = keras.Model(inputs=inputs, outputs=y)

model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
```

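As a brief aside (not in the original snippet), the shortcut mentioned in the
comments above looks like this, and the currently active distribution can be
inspected with `keras.distribution.distribution()`:

```python
# Shortcut form: with no arguments, DataParallel picks up all the locally
# available devices on its own.
data_parallel = keras.distribution.DataParallel()
keras.distribution.set_distribution(data_parallel)

# Inspect the currently active global distribution.
print(keras.distribution.distribution())
```
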
## `ModelParallel` and `LayoutMap`

`ModelParallel` is mostly useful when the model weights are too large to fit
on a single accelerator. This setting allows you to split your model weights or
activation tensors across all the devices on the `DeviceMesh`, and enables
horizontal scaling for large models.

Unlike the `DataParallel` model, where all weights are fully replicated,
the weights layout under `ModelParallel` usually needs some customization for
best performance. We introduce `LayoutMap` to let you specify the
`TensorLayout` for any weights and intermediate tensors from a global perspective.

`LayoutMap` is a dict-like object that maps a string to `TensorLayout`
instances. It behaves differently from a normal Python dict in that the string
key is treated as a regex when retrieving the value. The class allows you to
define the naming schema of `TensorLayout` and then retrieve the corresponding
`TensorLayout` instance. Typically, the key used to query
is the `variable.path` attribute, which is the identifier of the variable.
As a shortcut, a tuple or list of axis
names is also allowed when inserting a value, and it will be converted to
`TensorLayout`.

The `LayoutMap` can also optionally contain a `DeviceMesh` to populate the
`TensorLayout.device_mesh` if it is not set. When retrieving a layout with a
key, if there isn't an exact match, all existing keys in the layout map will
be treated as regexes and matched against the input key again. If there are
multiple matches, a `ValueError` is raised. If no matches are found, `None` is
returned (see the lookup sketch after the code block below).

```python
mesh_2d = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh_2d)
# The rule below means that any weights whose path matches `d1/kernel` will be
# sharded along the "model" dimension (4 devices); the same applies to `d1/bias`.
# All other weights will be fully replicated.
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)

# You can also set the layout for the layer output, for example:
layout_map["d2/output"] = ("data", None)

model_parallel = keras.distribution.ModelParallel(
    mesh_2d, layout_map, batch_dim_name="data"
)

keras.distribution.set_distribution(model_parallel)

inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu", name="d1")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax", name="d2")(y)
model = keras.Model(inputs=inputs, outputs=y)

# The data will be sharded across the "data" dimension of the mesh, which
# has 2 devices.
model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
```

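To make the lookup behavior described above concrete, here is a minimal sketch
(not part of the original snippet) reusing the `layout_map` defined above; the
return values follow the documented behavior, a `TensorLayout` for a match and
`None` otherwise:

```python
# An exact key returns the layout registered for it (the tuple inserted above
# is converted to a TensorLayout).
print(layout_map["d1/kernel"])

# A query with no exact match is compared against every stored key treated as
# a regex; if nothing matches, the result is None.
print(layout_map["dense_3/kernel"])  # None
```
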
It is also easy to change the mesh structure to tune the computation between
more data parallelism or more model parallelism. You can do this simply by
adjusting the shape of the mesh; no changes are needed in any other code
(a short sketch of switching meshes follows the code block below).

```python
full_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(8, 1), axis_names=["data", "model"], devices=devices
)
more_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(4, 2), axis_names=["data", "model"], devices=devices
)
more_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
full_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8), axis_names=["data", "model"], devices=devices
)
```

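As a hedged follow-up (not part of the original snippet), switching the balance
only means rebuilding the layout rules and the distribution around the new mesh
and setting it globally again before constructing the model:

```python
# Rebuild the layout rules and the distribution around the new mesh; the
# model-building code itself stays exactly the same.
layout_map = keras.distribution.LayoutMap(more_data_parallel_mesh)
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)
layout_map["d2/output"] = ("data", None)

model_parallel = keras.distribution.ModelParallel(
    more_data_parallel_mesh, layout_map, batch_dim_name="data"
)
keras.distribution.set_distribution(model_parallel)
```
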
### Further reading

1. [JAX Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
2. [JAX sharding module](https://jax.readthedocs.io/en/latest/jax.sharding.html)
3. [TensorFlow Distributed training with DTensors](https://www.tensorflow.org/tutorials/distribute/dtensor_ml_tutorial)
4. [TensorFlow DTensor concepts](https://www.tensorflow.org/guide/dtensor_overview)
5. [Using DTensors with tf.keras](https://www.tensorflow.org/tutorials/distribute/dtensor_keras_tutorial)
