
Commit 6b58810

feedback
1 parent 2e1ed24 · commit 6b58810

File tree: 1 file changed (+35, -129 lines)


docs/source/en/optimization/attention_backends.md

Lines changed: 35 additions & 129 deletions
@@ -22,17 +22,16 @@ Available attention implementations include the following.
 | PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
 | xFormers | memory-efficient attention with support for various attention kernels |

-This guide will show you how to use the dispatcher to set and use the different attention backends.
+This guide will show you how to set and use the different attention backends.

-## FlashAttention
+## set_attention_backend

-[FlashAttention](https://github.com/Dao-AILab/flash-attention) reduces memory traffic by making better use of on-chip shared memory (SRAM) instead of global GPU memory so the data doesn't have to travel far. The latest variant, FlashAttention-3, is further optimized for modern GPUs (Hopper/Blackwell) and also overlaps computations and handles FP8 attention better.
+The [`~ModelMixin.set_attention_backend`] method iterates through all the modules in the model and sets the appropriate attention backend to use. The attention backend setting persists until [`~ModelMixin.reset_attention_backend`] is called.

-There are several available FlashAttention variants, including variable length and the original FlashAttention. For a full list of supported implementations, check the list [here](https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py#L163).
+The example below demonstrates how to enable the `_flash_3_hub` implementation for FlashAttention-3 from the [kernel](https://github.com/huggingface/kernels) library, which allows you to instantly use optimized compute kernels from the Hub without requiring any setup.

-The example below demonstrates how to enable the `_flash_3_hub` implementation. The [kernel](https://github.com/huggingface/kernels) library allows you to instantly use optimized compute kernels from the Hub without requiring any setup.
-
-Pass the attention backend to the [`~ModelMixin.set_attention_backend`] method.
+> [!TIP]
+> FlashAttention-3 is not supported for non-Hopper architectures, in which case, use FlashAttention (set_attention_backend("flash")).

 ```py
 import torch
@@ -44,129 +43,15 @@ pipeline = QwenImagePipeline.from_pretrained(
 pipeline.transformer.set_attention_backend("_flash_3_hub")
 ```

-You could also use the [attention_backend](https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py#L225) context manager to temporarily set an attention backend for a model within the context.
-
-```py
-import torch
-from diffusers import QwenImagePipeline
-
-pipeline = QwenImagePipeline.from_pretrained(
-    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
-)
-prompt = """
-cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
-highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
-"""
-
-with attention_backend("_flash_3_hub"):
-    image = pipeline(prompt).images[0]
-```
-
-To restore the default attention backend, call [`~ModelMixin.reset_attention_backend`].
-
-```py
-pipeline.transformer.reset_attention_backend()
-```
-
-## SageAttention
-
-[SageAttention](https://github.com/thu-ml/SageAttention) quantizes attention by computing queries (Q) and keys (K) in INT8. The probability (P) and value (V) are calculated in either FP8 or FP16 to minimize error. This significantly increases inference throughput and with little to no degradation.
-
-There are several SageAttention variants for FP8 and FP16 as well as whether it is CUDA or Triton based. For a full list of supported implementations, check the list [here](https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py#L182).
-
-The example below uses the `_sage_qk_int8_pv_fp8_cuda` implementation.
-
-```py
-import torch
-from diffusers import QwenImagePipeline
-
-pipeline = QwenImagePipeline.from_pretrained(
-    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
-)
-pipeline.transformer.set_attention_backend("_sage_qk_int8_pv_fp8_cuda")
-```
-
-You could also use the [attention_backend](https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py#L225) context manager to temporarily set an attention backend for a model within the context.
-
-```py
-import torch
-from diffusers import QwenImagePipeline
-
-pipeline = QwenImagePipeline.from_pretrained(
-    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
-)
-prompt = """
-cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
-highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
-"""
-
-with attention_backend("_sage_qk_int8_pv_fp8_cuda"):
-    image = pipeline(prompt).images[0]
-```
-
-To restore the default attention backend, call [`~ModelMixin.reset_attention_backend`].
-
-```py
-pipeline.transformer.reset_attention_backend()
-```
-
-## PyTorch native
-
-PyTorch includes a [native implementation](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) of several optimized attention implementations including [FlexAttention](https://pytorch.org/blog/flexattention/), FlashAttention, memory-efficient attention, and a C++ version.
-
-For a full list of supported implementations, check the list [here](https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py#L171).
-
-The example below uses the `_native_flash` implementation.
-
-```py
-import torch
-from diffusers import QwenImagePipeline
-
-pipeline = QwenImagePipeline.from_pretrained(
-    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
-)
-pipeline.transformer.set_attention_backend("_native_flash")
-```
-
-You could also use the [attention_backend](https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py#L225) context manager to temporarily set an attention backend for a model within the context.
-
-```py
-import torch
-from diffusers import QwenImagePipeline
-
-pipeline = QwenImagePipeline.from_pretrained(
-    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
-)
-prompt = """
-cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
-highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
-"""
-
-with attention_backend("_native_flash"):
-    image = pipeline(prompt).images[0]
-```
-
 To restore the default attention backend, call [`~ModelMixin.reset_attention_backend`].

 ```py
 pipeline.transformer.reset_attention_backend()
 ```

-## xFormers
-
-[xFormers](https://github.com/facebookresearch/xformers) provides memory-efficient attention algorithms such as sparse attention and block-sparse attention. Pass `xformers` to enable it.
-
-```py
-import torch
-from diffusers import QwenImagePipeline
-
-pipeline = QwenImagePipeline.from_pretrained(
-    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
-)
-pipeline.transformer.set_attention_backend("xformers")
-```
+## attention_backend context manager

-You could also use the [attention_backend](https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py#L225) context manager to temporarily set an attention backend for a model within the context.
+The [attention_backend](https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py#L225) context manager temporarily sets an attention backend for a model within the context. Outside the context, the default attention (PyTorch's native scaled dot product attention) is used. This is useful if you want to use different backends for different parts of a pipeline or if you want to test the different backends.

 ```py
 import torch
@@ -180,12 +65,33 @@ cinematic film still of a cat sipping a margarita in a pool in Palm Springs, Cal
 highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
 """

-with attention_backend("xformers"):
+with attention_backend("_flash_3_hub"):
     image = pipeline(prompt).images[0]
 ```

-To restore the default attention backend, call [`~ModelMixin.reset_attention_backend`].
-
-```py
-pipeline.transformer.reset_attention_backend()
-```
+## Available backends
+
+Refer to the table below for available attention backends.
+
+| Backend Name | Family | Description |
+|--------------|--------|-------------|
+| `native` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Default backend using PyTorch's scaled_dot_product_attention |
+| `flex` | [FlexAttention](https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention) | PyTorch FlexAttention implementation |
+| `_native_cudnn` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | CuDNN-optimized attention |
+| `_native_efficient` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Memory-efficient attention |
+| `_native_flash` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | PyTorch's FlashAttention |
+| `_native_math` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Math-based attention (fallback) |
+| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
+| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
+| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
+| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
+| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
+| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
+| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
+| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
+| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
+| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
+| `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |
+| `_sage_qk_int8_pv_fp16_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (CUDA) |
+| `_sage_qk_int8_pv_fp16_triton` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (Triton) |
+| `xformers` | [xFormers](https://github.com/facebookresearch/xformers) | Memory-efficient attention |
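
Taken together, the page this commit updates boils down to a short workflow: pick a backend name from the table, set it on the model, and reset it when done. The sketch below is not part of the diff; it mirrors the examples above using the `flash` backend suggested in the tip for non-Hopper GPUs, and assumes FlashAttention-2 is installed (the shortened prompt is only for brevity).

```py
import torch
from diffusers import QwenImagePipeline

pipeline = QwenImagePipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
)

# Persistently switch the transformer to FlashAttention-2 ("flash" in the table).
# The setting applies to every following call until it is reset.
pipeline.transformer.set_attention_backend("flash")
image = pipeline("cinematic film still of a cat sipping a margarita in a pool").images[0]

# Restore the default backend (PyTorch's scaled_dot_product_attention).
pipeline.transformer.reset_attention_backend()
```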
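
The context-manager examples in the diff call `attention_backend(...)` without showing its import. A plausible sketch is below, assuming the context manager is imported from the `attention_dispatch.py` module the doc links to; the exact public import path may differ across diffusers versions. The `_native_cudnn` name is taken from the table and needs no extra dependency.

```py
import torch
from diffusers import QwenImagePipeline
# Assumed import location, based on the attention_dispatch.py file linked above;
# the public import path may vary between diffusers versions.
from diffusers.models.attention_dispatch import attention_backend

pipeline = QwenImagePipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
)

# The backend applies only inside the context; outside it, the default
# PyTorch scaled_dot_product_attention path is used again.
with attention_backend("_native_cudnn"):
    image = pipeline("cinematic film still of a cat sipping a margarita in a pool").images[0]
```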
