
Commit efef58c

cache
1 parent 20273e5 commit efef58c


3 files changed, +72 -56 lines changed


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -178,6 +178,8 @@
 - sections:
   - local: optimization/fp16
     title: Accelerate inference
+  - local: optimization/cache
+    title: Caching
   - local: optimization/memory
     title: Reduce memory usage
   - local: optimization/xformers

docs/source/en/api/cache.md

Lines changed: 4 additions & 56 deletions
@@ -11,71 +11,19 @@ specific language governing permissions and limitations under the License. -->
 
 # Caching methods
 
-## Pyramid Attention Broadcast
+Cache methods speed up diffusion transformers by storing and reusing attention states instead of recalculating them.
 
-[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
-
-Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
-
-Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
-
-```python
-import torch
-from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
-
-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
-pipe.to("cuda")
-
-# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
-# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
-# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
-# poorer quality of generated videos.
-config = PyramidAttentionBroadcastConfig(
-    spatial_attention_block_skip_range=2,
-    spatial_attention_timestep_skip_range=(100, 800),
-    current_timestep_callback=lambda: pipe.current_timestep,
-)
-pipe.transformer.enable_cache(config)
-```
-
-## Faster Cache
-
-[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
-
-FasterCache is a method that speeds up inference in diffusion transformers by:
-- Reusing attention states between successive inference steps, due to high similarity between them
-- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output
-
-```python
-import torch
-from diffusers import CogVideoXPipeline, FasterCacheConfig
-
-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
-pipe.to("cuda")
-
-config = FasterCacheConfig(
-    spatial_attention_block_skip_range=2,
-    spatial_attention_timestep_skip_range=(-1, 681),
-    current_timestep_callback=lambda: pipe.current_timestep,
-    attention_weight_callback=lambda _: 0.3,
-    unconditional_batch_skip_range=5,
-    unconditional_batch_timestep_skip_range=(-1, 781),
-    tensor_format="BFCHW",
-)
-pipe.transformer.enable_cache(config)
-```
-
-### CacheMixin
+## CacheMixin
 
 [[autodoc]] CacheMixin
 
-### PyramidAttentionBroadcastConfig
+## PyramidAttentionBroadcastConfig
 
 [[autodoc]] PyramidAttentionBroadcastConfig
 
 [[autodoc]] apply_pyramid_attention_broadcast
 
-### FasterCacheConfig
+## FasterCacheConfig
 
 [[autodoc]] FasterCacheConfig
 
docs/source/en/optimization/cache.md (new file)

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# Caching

Caching accelerates inference by storing and reusing attention outputs instead of recomputing them at every step. It significantly improves efficiency and doesn't require additional training.

This guide shows you how to use the caching methods supported in Diffusers.

## Pyramid Attention Broadcast

[Pyramid Attention Broadcast (PAB)](https://huggingface.co/papers/2408.12588) is based on the observation that attention outputs change very little between successive inference steps, so much of the computation is redundant. The differences are smallest in the cross attention blocks, so their cached attention states are broadcast and reused over the longest range, followed by the temporal attention blocks and finally the spatial attention blocks.

PAB can be combined with other techniques like sequence parallelism and classifier-free guidance parallelism for near real-time video generation.

Set up and pass a [`PyramidAttentionBroadcastConfig`] to a pipeline's transformer to enable it. The `spatial_attention_block_skip_range` controls how often to skip attention calculations in the spatial attention blocks, and the `spatial_attention_timestep_skip_range` is the range of timesteps in which skipping is active. Take care to choose an appropriate range because a smaller interval can lead to slower inference speeds and a larger interval can result in lower generation quality.

```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipeline.to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipeline.current_timestep,
)
pipeline.transformer.enable_cache(config)
```
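
Conceptually, the cache recomputes an attention block's output only every `spatial_attention_block_skip_range` calls while the current timestep falls inside `spatial_attention_timestep_skip_range`, and broadcasts the stored output to the skipped calls in between. The following is a minimal sketch of that idea only, not the diffusers implementation; the `ReuseAttention` wrapper and its argument names are hypothetical.

```python
import torch


class ReuseAttention(torch.nn.Module):
    """Recompute attention every `skip_range` calls inside the active timestep
    window; otherwise broadcast (reuse) the cached output."""

    def __init__(self, attn, get_timestep, skip_range=2, timestep_range=(100, 800)):
        super().__init__()
        self.attn = attn
        self.get_timestep = get_timestep  # e.g. lambda: pipeline.current_timestep
        self.skip_range = skip_range
        self.timestep_range = timestep_range
        self.calls = 0
        self.cache = None

    def forward(self, hidden_states):
        low, high = self.timestep_range
        in_window = low < self.get_timestep() < high
        if not in_window or self.cache is None or self.calls % self.skip_range == 0:
            self.cache = self.attn(hidden_states)  # full attention computation
        self.calls += 1
        return self.cache  # reused (broadcast) on the skipped calls
```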

## FasterCache

[FasterCache](https://huggingface.co/papers/2410.19355) caches attention features and recomputes them only at every other timestep, because directly reusing the same cached features for too long can cause flickering or blurry details in the generated video. The features for a skipped step are instead estimated from the difference between the adjacent cached features.

FasterCache also uses a classifier-free guidance (CFG) cache which computes both the conditional and unconditional outputs once. For future timesteps, only the conditional output is calculated and the unconditional output is estimated from the cached biases.

Set up and pass a [`FasterCacheConfig`] to a pipeline's transformer to enable it.

```python
import torch
from diffusers import CogVideoXPipeline, FasterCacheConfig

pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipeline.to("cuda")

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    current_timestep_callback=lambda: pipeline.current_timestep,
    attention_weight_callback=lambda _: 0.3,
    unconditional_batch_skip_range=5,
    unconditional_batch_timestep_skip_range=(-1, 781),
    tensor_format="BFCHW",
)
pipeline.transformer.enable_cache(config)
```
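
The CFG part of the cache can be pictured with a small sketch: every few steps both branches are run and their difference is stored, and on the remaining steps the unconditional output is reconstructed from the conditional output plus that stored difference. This is an illustration of the idea only, not the diffusers implementation; `cfg_step`, the toy `model`, and the refresh logic are hypothetical.

```python
import torch


def cfg_step(model, x, t, cond_emb, uncond_emb, state, refresh_every=5, guidance_scale=6.0):
    """One classifier-free guidance step with a cached unconditional branch."""
    cond_out = model(x, t, cond_emb)
    if state["delta"] is None or state["step"] % refresh_every == 0:
        uncond_out = model(x, t, uncond_emb)    # run the unconditional branch for real
        state["delta"] = uncond_out - cond_out  # cache the branch difference ("bias")
    else:
        uncond_out = cond_out + state["delta"]  # approximate the skipped branch
    state["step"] += 1
    return uncond_out + guidance_scale * (cond_out - uncond_out)


# usage with a toy stand-in for a diffusion transformer
model = lambda x, t, emb: 0.9 * x + 0.1 * emb
state = {"step": 0, "delta": None}
x, cond, uncond = torch.randn(1, 4), torch.randn(1, 4), torch.randn(1, 4)
for step in range(10):
    noise_pred = cfg_step(model, x, step, cond, uncond, state)
```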
