
Commit d9c9691

[docs] Model sharding (huggingface#9521)

* flux shard
* feedback

1 parent 065ce07 commit d9c9691

File tree: 2 files changed (+130, -2 lines)


docs/source/en/_toctree.yml

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@
  - local: using-diffusers/overview_techniques
    title: Overview
  - local: training/distributed_inference
-   title: Distributed inference with multiple GPUs
+   title: Distributed inference
  - local: using-diffusers/merge_loras
    title: Merge LoRAs
  - local: using-diffusers/scheduler_features

docs/source/en/training/distributed_inference.md

Lines changed: 129 additions & 1 deletion

@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

-# Distributed inference with multiple GPUs
+# Distributed inference

On distributed setups, you can run inference across multiple GPUs with 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) or [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html), which is useful for generating with multiple prompts in parallel.

@@ -109,3 +109,131 @@ torchrun run_distributed.py --nproc_per_node=2

> [!TIP]
> You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.

## Model sharding

Modern diffusion systems such as [Flux](../api/pipelines/flux) are very large and have multiple models. For example, [Flux.1-Dev](https://hf.co/black-forest-labs/FLUX.1-dev) is made up of two text encoders - [T5-XXL](https://hf.co/google/t5-v1_1-xxl) and [CLIP-L](https://hf.co/openai/clip-vit-large-patch14) - a [diffusion transformer](../api/models/flux_transformer), and a [VAE](../api/models/autoencoderkl). With a model this size, it can be challenging to run inference on consumer GPUs.

Model sharding is a technique that distributes models across GPUs when the models don't fit on a single GPU. The example below assumes two 16GB GPUs are available for inference.

Start by computing the text embeddings with the text encoders. Keep the text encoders on two GPUs by setting `device_map="balanced"`. The `balanced` strategy evenly distributes the model on all available GPUs. Use the `max_memory` parameter to allocate the maximum amount of memory for each text encoder on each GPU.

> [!TIP]
> **Only** load the text encoders for this step! The diffusion transformer and VAE are loaded in a later step to preserve memory.

```py
from diffusers import FluxPipeline
import torch

prompt = "a photo of a dog with cat-like look"

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=None,
    vae=None,
    device_map="balanced",
    max_memory={0: "16GB", 1: "16GB"},
    torch_dtype=torch.bfloat16
)
with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=512
    )
```
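
To see how the `balanced` strategy actually split the two text encoders, you can print the pipeline's device map, as the tip later in this section suggests. This is a minimal check; the exact mapping depends on your hardware.

```py
# Shows which GPU each model-level component was assigned to, for example the
# two text encoders split across GPU 0 and GPU 1 (exact mapping varies by setup).
print(pipeline.hf_device_map)
```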

Once the text embeddings are computed, remove the text encoders from the GPU to make space for the diffusion transformer.

```py
import gc

def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()

del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.tokenizer
del pipeline.tokenizer_2
del pipeline

flush()
```
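
Optionally, you can confirm that the text encoders were actually released before loading the transformer. This is an illustrative check with plain PyTorch, not part of the original example.

```py
# Report how much memory is still allocated on each GPU after flushing.
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.memory_allocated(i) / 1024**3:.2f} GB allocated")
```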

Next, load the diffusion transformer, which has 12.5B parameters. This time, set `device_map="auto"` to automatically distribute the model across two 16GB GPUs. The `auto` strategy is backed by [Accelerate](https://hf.co/docs/accelerate/index) and available as a part of the [Big Model Inference](https://hf.co/docs/accelerate/concept_guides/big_model_inference) feature. It fills the fastest devices (the GPUs) first and only moves to slower devices like the CPU and hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference latency.

```py
from diffusers import FluxTransformer2DModel
import torch

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    device_map="auto",
    torch_dtype=torch.bfloat16
)
```

> [!TIP]
> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models.
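
Since the pipeline object has been deleted at this point, you can also inspect the standalone transformer with plain PyTorch. This is a minimal sketch; the exact placement depends on your hardware and available memory.

```py
# List the devices that hold the transformer's parameters after dispatch.
# With two 16GB GPUs and no CPU/disk offloading, this typically shows cuda:0 and cuda:1.
print({p.device for p in transformer.parameters()})
```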

Add the transformer model to the pipeline for denoising, but set the other model-level components like the text encoders and VAE to `None` because you don't need them yet.

```py
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    transformer=transformer,
    torch_dtype=torch.bfloat16
)

print("Running denoising.")
height, width = 768, 1360
latents = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=50,
    guidance_scale=3.5,
    height=height,
    width=width,
    output_type="latent",
).images
```

Remove the pipeline and transformer from memory as they're no longer needed.

```py
del pipeline.transformer
del pipeline

flush()
```

Finally, decode the latents with the VAE into an image. The VAE is typically small enough to be loaded on a single GPU.

```py
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
import torch

vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

with torch.no_grad():
    print("Running decoding.")
    latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

    image = vae.decode(latents, return_dict=False)[0]
    image = image_processor.postprocess(image, output_type="pil")
    image[0].save("split_transformer.png")
```
By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.
