docs/source/en/optimization/attention_backends.md (10 additions, 2 deletions)
@@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License. -->
 
 # Attention backends
 
-> [!TIP]
+> [!NOTE]
 > The attention dispatcher is an experimental feature. Please open an issue if you have any feedback or encounter any problems.
 
 Diffusers provides several optimized attention algorithms that are more memory and computationally efficient through its *attention dispatcher*. The dispatcher acts as a router for managing and switching between different attention implementations and provides a unified interface for interacting with them.
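For context, a minimal sketch of how the dispatcher is typically driven through `set_attention_backend`; the checkpoint and prompt below are illustrative assumptions and are not part of this diff:

```py
import torch
from diffusers import DiffusionPipeline

# the checkpoint is an assumption for illustration; any model that supports the dispatcher works
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# route every attention module in the transformer through FlashAttention
pipeline.transformer.set_attention_backend("flash")
image = pipeline("a photo of a cat").images[0]

# switch back to PyTorch's default scaled_dot_product_attention backend
pipeline.transformer.set_attention_backend("native")
```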
@@ -33,7 +33,7 @@ The [`~ModelMixin.set_attention_backend`] method iterates through all the module
 
 The example below demonstrates how to enable the `_flash_3_hub` implementation for FlashAttention-3 from the [kernel](https://github.com/huggingface/kernels) library, which allows you to instantly use optimized compute kernels from the Hub without requiring any setup.
 
-> [!TIP]
+> [!NOTE]
 > FlashAttention-3 is not supported for non-Hopper architectures, in which case, use FlashAttention with `set_attention_backend("flash")`.
 
 ```py
@@ -78,10 +78,16 @@ with attention_backend("_flash_3_hub"):
     image = pipeline(prompt).images[0]
 ```
 
+> [!TIP]
+> Most attention backends support `torch.compile` without graph breaks and can be used to further speed up inference.
+
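As an illustration of that tip, a minimal sketch of pairing a backend with `torch.compile`; the checkpoint, backend choice, and compile flags are assumptions for the example:

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# choose a dispatcher backend, then compile the transformer for additional speedups
pipeline.transformer.set_attention_backend("flash")
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
image = pipeline("a photo of a cat").images[0]
```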
 
 ## Available backends
 
 Refer to the table below for a complete list of available attention backends and their variants.
 
+<details>
+<summary>Expand</summary>
+
 
 | Backend Name | Family | Description |
 |--------------|--------|-------------|
 |`native`|[PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend)| Default backend using PyTorch's scaled_dot_product_attention |
@@ -104,3 +110,5 @@ Refer to the table below for a complete list of available attention backends and
docs/source/en/quantization/torchao.md (66 additions, 39 deletions)
@@ -11,69 +11,96 @@ specific language governing permissions and limitations under the License. -->
 
 # torchao
 
-[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.
+[torchao](https://github.com/pytorch/ao) provides high-performance dtypes and optimizations based on quantization and sparsity for inference and training PyTorch models. It is supported for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
 
-Before you begin, make sure you have Pytorch 2.5+ and TorchAO installed.
+Make sure PyTorch 2.5+ and torchao are installed with the command below.
 
 ```bash
-pip install -U torch torchao
+uv pip install -U torch torchao
 ```
 
+Each quantization dtype is available as a separate instance of an [AOBaseConfig](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) class. This provides more flexible configuration options by exposing more of the available arguments.
 
-Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
+Pass the `AOBaseConfig` of a quantization dtype, like [Int4WeightOnlyConfig](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int4WeightOnlyConfig), to [`TorchAoConfig`] in [`~ModelMixin.from_pretrained`].
 
-The example below only quantizes the weights to int8.
-
-```python
+```py
 import torch
-from diffusers import FluxPipeline, AutoModel, TorchAoConfig
-
-model_id = "black-forest-labs/FLUX.1-dev"
-dtype = torch.bfloat16
+from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
+from torchao.quantization import Int8WeightOnlyConfig
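The rest of the updated example is truncated in this rendering. A minimal sketch of how these imports typically fit together; the checkpoint, component name, and prompt are assumptions:

```py
import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
from torchao.quantization import Int8WeightOnlyConfig

# wrap the torchao config in TorchAoConfig and map it onto the pipeline's transformer
pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig())}
)

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
image = pipeline("a photo of a cat").images[0]
```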
-TorchAO is fully compatible with [torch.compile](../optimization/fp16#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.
+## torch.compile
+
+torchao supports [torch.compile](../optimization/fp16#torchcompile), which can speed up inference with one line of code.
 
 ```python
-# In the above code, add the following after initializing the transformer
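The updated compile example itself is cut off here. A short sketch of the usual pattern, assuming `pipeline` was loaded with a `TorchAoConfig` as in the sketch above:

```py
import torch

# compile the TorchAO-quantized transformer; assumes `pipeline` from the previous sketch
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
image = pipeline("a photo of a cat").images[0]
```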
-For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.
+Refer to this [table](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450) for inference speed and memory usage benchmarks with Flux and CogVideoX. More benchmarks on various hardware are also available in the torchao [repository](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).
 
 > [!TIP]
 > The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible.
 
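To make that recommendation concrete, a brief sketch of an FP8 weight-only setup; `Float8WeightOnlyConfig` is assumed from torchao's config API and is not part of this diff:

```py
from diffusers import PipelineQuantizationConfig, TorchAoConfig
from torchao.quantization import Float8WeightOnlyConfig

# FP8 weight-only quantization of the transformer; requires compute capability >= 8.9
pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(Float8WeightOnlyConfig())}
)
```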
-torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
+## autoquant
+
+torchao provides [autoquant](https://docs.pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant), an automatic quantization API. Autoquantization chooses the best quantization strategy by comparing the performance of each strategy on chosen input types and shapes. This is only supported in Diffusers for individual models at the moment.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+from torchao.quantization import autoquant
+
+# Load the pipeline
+pipeline = DiffusionPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell",
+    torch_dtype=torch.bfloat16,
+    device_map="cuda"
+)
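The remainder of the added example is cut off here. A hedged continuation showing how autoquant is typically applied to an individual model, reusing the `pipeline` and `autoquant` import from the added code above; the prompt is an assumption:

```py
# quantize the transformer (an individual model), then run a prompt so autoquant
# can benchmark candidate strategies on real input shapes
pipeline.transformer = autoquant(pipeline.transformer)
image = pipeline("a dog running through a field").images[0]
```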
 
-The `TorchAoConfig` class accepts three parameters:
-- `quant_type`: A string value mentioning one of the quantization types below.
-- `modules_to_not_convert`: A list of full/partial module names for which quantization should not be performed. For example, to not perform any quantization of the [`FluxTransformer2DModel`]'s first block, one would specify: `modules_to_not_convert=["single_transformer_blocks.0"]`.
-- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`.