Commit d3425a7

feedback
1 parent a48ac35 commit d3425a7

File tree

1 file changed: +3 -4 lines changed


docs/source/en/training/distributed_inference.md

Lines changed: 3 additions & 4 deletions
@@ -242,7 +242,7 @@ By selectively loading and unloading the models you need at a given stage and sh
 
 [Context parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism) reduces memory by splitting input sequences across multiple GPUs. Each GPU processes its own slice of the sequence.
 
-The key (K) and value (V) representations are communicated between devices with [Ring Attention](https://nanotron-ultrascale-playbook.static.hf.space/index.html?section=second_optimization%3A_bucketing_gradients#ring_attention) to ensure each split can see every other token's K/V. In Ring Attention, each GPU computes attention for its local K/V and passes it to the next GPU in the ring. This way, no single GPU has to hold the full sequence, which reduces memory use and communication latency.
+The key (K) and value (V) representations are communicated between devices with [Ring Attention](https://huggingface.co/papers/2310.01889) to ensure each split can see every other token's K/V. In Ring Attention, each GPU computes attention for its local K/V and passes it to the next GPU in the ring. This way, no single GPU has to hold the full sequence, which reduces memory use and communication latency.
 
 Call [`parallelize`] on the model and pass a [`ContextParallelConfig`]. This config supports the `ring_degree` argument, which determines the number of devices to use for Ring Attention.
 
@@ -263,8 +263,7 @@ try:
     device = torch.device("cuda", rank % torch.cuda.device_count())
     torch.cuda.set_device(device)
 
-    pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
-    pipeline.to("cuda")
+    pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda")
 
     pipeline.transformer.parallelize(config=ContextParallelConfig(ring_degree=2))
     pipeline.transformer.set_attention_backend("flash")
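The hunk above only shows the body of the docs example's `try` block. Below is a minimal, self-contained sketch of the same flow, assuming a 2-GPU node launched with `torchrun --nproc_per_node=2` and that `QwenImagePipeline`, `ContextParallelConfig`, `parallelize`, and `set_attention_backend` behave as shown in the diff; the process-group setup, seeding, and prompt are illustrative assumptions, not lines from the original file.

```python
# Hypothetical standalone sketch of the context-parallel (Ring Attention) flow above.
# Assumed launch: torchrun --nproc_per_node=2 cp_inference.py
import torch
import torch.distributed as dist
from diffusers import ContextParallelConfig, QwenImagePipeline

def main():
    # Process-group setup is assumed; the diff hunks only show the try-block body.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()

    try:
        device = torch.device("cuda", rank % torch.cuda.device_count())
        torch.cuda.set_device(device)

        pipeline = QwenImagePipeline.from_pretrained(
            "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
        )

        # Ring Attention: K/V are split across ring_degree=2 devices and passed around the ring.
        pipeline.transformer.parallelize(config=ContextParallelConfig(ring_degree=2))
        pipeline.transformer.set_attention_backend("flash")

        # Same seed on every rank so all ranks denoise identical latents (an assumption
        # about the surrounding example, which the hunks above do not show).
        generator = torch.Generator(device=device).manual_seed(0)

        # Placeholder prompt; every rank runs the same call and the K/V exchange
        # happens inside the parallelized attention layers.
        image = pipeline("a photo of a cat sitting on a windowsill", generator=generator).images[0]
        if rank == 0:
            image.save("output.png")
    finally:
        dist.destroy_process_group()

if __name__ == "__main__":
    main()
```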
@@ -294,7 +293,7 @@ finally:
 
 ### Ulysses Attention
 
-Ulysses Attention splits a sequence across GPUs and performs an *all-to-all* (every device sends/receives data to every other device) so that each GPU ends up with all the tokens for only a subset of the attention heads. Each GPU computes attention locally on all tokens for its heads and then performs another all-to-all to regroup the results by tokens, making them ready for the next layer.
+[Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* (every device sends/receives data to every other device) so that each GPU ends up with all the tokens for only a subset of the attention heads. Each GPU computes attention locally on all tokens for its heads and then performs another all-to-all to regroup the results by tokens, making them ready for the next layer.
 
 [`ContextParallelConfig`] also supports Ulysses Attention through the `ulysses_degree` argument. This determines the number of devices to use for Ulysses Attention.
 
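For comparison with the `ring_degree` example above, a minimal sketch of how `ulysses_degree` might be passed, assuming the same `pipeline` setup and `parallelize` API shown in the earlier hunk:

```python
# Hypothetical usage, mirroring the ring_degree example; `pipeline` is assumed to be the
# QwenImagePipeline already loaded under torchrun as in the previous sketch.
from diffusers import ContextParallelConfig

# Ulysses Attention: an all-to-all gives each of the 2 devices every token
# for a subset of the attention heads, then attention is computed locally.
pipeline.transformer.parallelize(config=ContextParallelConfig(ulysses_degree=2))
pipeline.transformer.set_attention_backend("flash")
```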