
Commit 90409dd: feedback
1 parent f2d1133 commit 90409dd

File tree: 1 file changed, +115 -0 lines changed

docs/source/en/training/distributed_inference.md
Lines changed: 115 additions & 0 deletions
@@ -111,6 +111,121 @@ Call `torchrun` to run the inference script and use the `--nproc_per_node` argum
torchrun run_distributed.py --nproc_per_node=2
```

## device_map

The `device_map` argument enables distributed inference by automatically placing model components on separate GPUs. This is especially useful when a model doesn't fit on a single GPU. You can use `device_map` to selectively load and unload the required model components at a given stage, as shown in the example below (it assumes two GPUs are available).

Set `device_map="balanced"` to evenly distribute the text encoders across all available GPUs. You can use the `max_memory` argument to allocate a maximum amount of memory for each text encoder. Don't load any other pipeline components to avoid unnecessary memory usage.
```py
from diffusers import FluxPipeline
import torch

prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=None,
    vae=None,
    device_map="balanced",
    max_memory={0: "16GB", 1: "16GB"},
    torch_dtype=torch.bfloat16
)
with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=512
    )
```
After the text embeddings are computed, remove the text encoders from the GPU to make space for the diffusion transformer.
```py
import gc

def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()

del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.tokenizer
del pipeline.tokenizer_2
del pipeline

flush()
```
Set `device_map="auto"` to automatically distribute the model across the two GPUs. This strategy places the model on the fastest device first, and only spills over to slower devices like the CPU or hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference.
```py
from diffusers import AutoModel
import torch

transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    device_map="auto",
    torch_dtype=torch.bfloat16
)
```
> [!TIP]
> Run `pipeline.hf_device_map` to see how the various models are distributed across devices. This is useful for tracking model device placement. You can also call `hf_device_map` on the transformer model to see how it is distributed.
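For example, after loading the device-mapped transformer above, a quick check might look like the following sketch (the exact placement it prints depends on your hardware and available memory):

```py
# Print where each transformer submodule was placed by device_map="auto".
# The mapping varies with your hardware and available memory.
print(transformer.hf_device_map)
```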
Add the transformer model to the pipeline and set `output_type="latent"` to generate the latents.
```py
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    transformer=transformer,
    torch_dtype=torch.bfloat16
)

print("Running denoising.")
height, width = 768, 1360
latents = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=50,
    guidance_scale=3.5,
    height=height,
    width=width,
    output_type="latent",
).images
```
Remove the pipeline and transformer from memory and load a VAE to decode the latents. The VAE is typically small enough to be loaded on a single device.
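A minimal way to free them, reusing the `flush` helper defined earlier (assuming you are continuing in the same session):

```py
# Drop the references to the device-mapped transformer and the pipeline,
# then clear the CUDA cache with the flush() helper defined earlier.
del pipeline.transformer
del pipeline
del transformer

flush()
```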
```py
import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor

vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

with torch.no_grad():
    print("Running decoding.")
    latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

    image = vae.decode(latents, return_dict=False)[0]
    image = image_processor.postprocess(image, output_type="pil")
    image[0].save("split_transformer.png")
```
## Resources

- Take a look at this [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for a minimal example of distributed inference with Accelerate.
