docs/source/en/training/distributed_inference.md

Call `torchrun` to run the inference script and use the `--nproc_per_node` argument to specify the number of GPUs to use.

```bash
torchrun run_distributed.py --nproc_per_node=2
```

## device_map

The `device_map` argument enables distributed inference by automatically placing model components on separate GPUs. This is especially useful when a model doesn't fit on a single GPU. You can use `device_map` to selectively load and unload the required model components at a given stage, as shown in the example below (which assumes two GPUs are available).

Set `device_map="balanced"` to evenly distribute the text encoders across all available GPUs. Use the `max_memory` argument to allocate a maximum amount of memory to each text encoder. Don't load any other pipeline components to avoid unnecessary memory usage.

```py
from diffusers import FluxPipeline
import torch

prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
```
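
Below is a minimal sketch of the loading and prompt-encoding step, assuming the text encoders are loaded with `device_map="balanced"` and the embeddings are computed with the pipeline's `encode_prompt` method; the `max_memory` values are placeholders to adjust for your own GPUs.

```py
# Sketch: load only the text encoders/tokenizers with device_map="balanced".
# The max_memory values below are assumptions -- set them to match your GPUs.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=None,
    vae=None,
    device_map="balanced",
    max_memory={0: "16GB", 1: "16GB"},
    torch_dtype=torch.bfloat16
)

with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=512
    )
```
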
After the text embeddings are computed, remove them from the GPU to make space for the diffusion transformer.

```py
import gc

def flush():
    # Release unreferenced memory and reset CUDA memory statistics
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()

# Delete the text encoders and tokenizers, then drop the pipeline itself
del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.tokenizer
del pipeline.tokenizer_2
del pipeline

flush()
```
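
To confirm the memory was actually released before loading the transformer, an optional sanity check like the following can help (this is an illustrative addition, not part of the original example).

```py
# Optional sanity check: report how much memory is still allocated on each GPU
for i in range(torch.cuda.device_count()):
    allocated_gib = torch.cuda.memory_allocated(i) / 1024**3
    print(f"cuda:{i} allocated: {allocated_gib:.2f} GiB")
```
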
Set `device_map="auto"` to automatically distribute the model across the two GPUs. This strategy places the model on the fastest device first and only falls back to slower devices, like the CPU or hard drive, if needed. The trade-off of storing model parameters on slower devices is slower inference latency.

```py
from diffusers import AutoModel
import torch

transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    device_map="auto",
    torch_dtype=torch.bfloat16
)
```

> [!TIP]
> Run `pipeline.hf_device_map` to see how the various models are distributed across devices. This is useful for tracking model device placement. You can also call `hf_device_map` on the transformer model to see how it is distributed.
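
For example, the placement of the transformer loaded above can be inspected directly; the mapping shown in the comment is only illustrative, since the actual assignment depends on your hardware.

```py
# Inspect where the transformer's modules were placed
print(transformer.hf_device_map)
# Prints something like {"": 0} if the model fits on one device, or a
# per-module mapping when it is split across devices
```
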
Add the transformer model to the pipeline and set `output_type="latent"` to generate the latents.

```py
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    transformer=transformer,
    torch_dtype=torch.bfloat16
)

print("Running denoising.")
height, width = 768, 1360
latents = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=50,
    guidance_scale=3.5,
    height=height,
    width=width,
    output_type="latent",
).images
```

Remove the pipeline and transformer from memory and load a VAE to decode the latents. The VAE is typically small enough to be loaded on a single device.

```py
import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
```
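
The decoding step itself is not shown above. A sketch of how it might look follows, assuming the packed Flux latents are unpacked with the pipeline's `_unpack_latents` helper and denormalized with the VAE's `scaling_factor` and `shift_factor` before decoding; treat it as an outline rather than a drop-in recipe, and note that the output filename is arbitrary.

```py
# Free the diffusion transformer before loading the VAE
del pipeline.transformer
del pipeline
del transformer
flush()

vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
).to("cuda")

# Assumed unpacking and denormalization steps for Flux latents
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

with torch.no_grad():
    latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
    image = vae.decode(latents, return_dict=False)[0]
    image = image_processor.postprocess(image, output_type="pil")
    image[0].save("image.png")
```
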
- Take a look at this [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for a minimal example of distributed inference with Accelerate.