
Commit ca3ab61

Merge branch 'huggingface:main' into glm
2 parents a10c830 + 8f15be1 commit ca3ab61

32 files changed: +5522 −9 lines

docs/source/en/_toctree.yml

Lines changed: 6 additions & 0 deletions
@@ -290,6 +290,8 @@
       title: CogView4Transformer2DModel
     - local: api/models/dit_transformer2d
       title: DiTTransformer2DModel
+    - local: api/models/easyanimate_transformer3d
+      title: EasyAnimateTransformer3DModel
     - local: api/models/flux_transformer
       title: FluxTransformer2DModel
     - local: api/models/hunyuan_transformer2d
@@ -352,6 +354,8 @@
       title: AutoencoderKLHunyuanVideo
     - local: api/models/autoencoderkl_ltx_video
       title: AutoencoderKLLTXVideo
+    - local: api/models/autoencoderkl_magvit
+      title: AutoencoderKLMagvit
     - local: api/models/autoencoderkl_mochi
       title: AutoencoderKLMochi
     - local: api/models/autoencoder_kl_wan
@@ -430,6 +434,8 @@
       title: DiffEdit
     - local: api/pipelines/dit
       title: DiT
+    - local: api/pipelines/easyanimate
+      title: EasyAnimate
     - local: api/pipelines/flux
       title: Flux
     - local: api/pipelines/control_flux_inpaint
docs/source/en/api/models/autoencoderkl_magvit.md

Lines changed: 37 additions & 0 deletions
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLMagvit

The 3D variational autoencoder (VAE) model with KL loss used in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) was introduced by Alibaba PAI.

The model can be loaded with the following code snippet.

```python
import torch
from diffusers import AutoencoderKLMagvit

vae = AutoencoderKLMagvit.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="vae", torch_dtype=torch.float16).to("cuda")
```

## AutoencoderKLMagvit

[[autodoc]] AutoencoderKLMagvit
  - decode
  - encode
  - all

## AutoencoderKLOutput

[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
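
Not part of the committed file, but as a quick editor's sketch of how the loaded VAE might be exercised — assuming the standard diffusers VAE interface (`encode()` returning an object with a `latent_dist`, `decode()` returning a `DecoderOutput` with a `.sample` attribute); the dummy tensor shape and sizes are illustrative only:

```python
import torch
from diffusers import AutoencoderKLMagvit

vae = AutoencoderKLMagvit.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

# Dummy video batch in (batch, channels, frames, height, width) layout -- the
# frame count and resolution here are placeholders, not recommended settings.
video = torch.randn(1, 3, 9, 256, 256, dtype=torch.float16, device="cuda")

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()  # compressed 3D latents
    reconstruction = vae.decode(latents).sample       # back to pixel space

print(latents.shape, reconstruction.shape)
```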
docs/source/en/api/models/easyanimate_transformer3d.md

Lines changed: 30 additions & 0 deletions
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# EasyAnimateTransformer3DModel

A Diffusion Transformer model for 3D data from [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) was introduced by Alibaba PAI.

The model can be loaded with the following code snippet.

```python
import torch
from diffusers import EasyAnimateTransformer3DModel

transformer = EasyAnimateTransformer3DModel.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
```

## EasyAnimateTransformer3DModel

[[autodoc]] EasyAnimateTransformer3DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
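
Not from the committed file; as a small, generic PyTorch sanity check on the loaded transformer (editor's sketch — nothing here is specific to `EasyAnimateTransformer3DModel` beyond the loading call shown above):

```python
import torch
from diffusers import EasyAnimateTransformer3DModel

transformer = EasyAnimateTransformer3DModel.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="transformer", torch_dtype=torch.float16
)

# Generic torch.nn.Module introspection: parameter count and weight dtype.
num_params = sum(p.numel() for p in transformer.parameters())
print(f"parameters: {num_params / 1e9:.2f}B")            # on the order of 12B, per the checkpoint name
print(f"dtype: {next(transformer.parameters()).dtype}")   # torch.float16, as requested above
```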
docs/source/en/api/pipelines/easyanimate.md

Lines changed: 88 additions & 0 deletions
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->

# EasyAnimate

[EasyAnimate](https://github.com/aigc-apps/EasyAnimate) is a video generation pipeline by Alibaba PAI.

The description from its GitHub page:
*EasyAnimate is a pipeline based on the transformer architecture, designed for generating AI images and videos, and for training baseline models and Lora models for Diffusion Transformer. We support direct prediction from pre-trained EasyAnimate models, allowing for the generation of videos with various resolutions, approximately 6 seconds in length, at 8fps (EasyAnimateV5.1, 1 to 49 frames). Additionally, users can train their own baseline and Lora models for specific style transformations.*

This pipeline was contributed by [bubbliiiing](https://github.com/bubbliiiing). The original codebase can be found [here](https://github.com/aigc-apps/EasyAnimate). The original weights can be found under [hf.co/alibaba-pai](https://huggingface.co/alibaba-pai).

There are two official EasyAnimate checkpoints for text-to-video and video-to-video.

| checkpoints | recommended inference dtype |
|:---:|:---:|
| [`alibaba-pai/EasyAnimateV5.1-12b-zh`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh) | torch.float16 |
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |

There is one official EasyAnimate checkpoint available for image-to-video and video-to-video.

| checkpoints | recommended inference dtype |
|:---:|:---:|
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |

There are two official EasyAnimate checkpoints available for control-to-video.

| checkpoints | recommended inference dtype |
|:---:|:---:|
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control) | torch.float16 |
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera) | torch.float16 |

For the EasyAnimateV5.1 series:
- Text-to-video (T2V) and image-to-video (I2V) work at multiple resolutions. The width and height can vary from 256 to 1024.
- Both the T2V and I2V models support generation with 1 to 49 frames and work best within this range. Exporting videos at 8 FPS is recommended.

## Quantization

Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have a varying impact on video quality depending on the video model.

Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`EasyAnimatePipeline`] for inference with bitsandbytes.

```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, EasyAnimateTransformer3DModel, EasyAnimatePipeline
from diffusers.utils import export_to_video

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = EasyAnimateTransformer3DModel.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

pipeline = EasyAnimatePipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh",
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

prompt = "A cat walks on the grass, realistic style."
negative_prompt = "bad detailed"
video = pipeline(prompt=prompt, negative_prompt=negative_prompt, num_frames=49, num_inference_steps=30).frames[0]
export_to_video(video, "cat.mp4", fps=8)
```

## EasyAnimatePipeline

[[autodoc]] EasyAnimatePipeline
  - all
  - __call__

## EasyAnimatePipelineOutput

[[autodoc]] pipelines.easyanimate.pipeline_output.EasyAnimatePipelineOutput
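
Alongside the quantized example above, a minimal full-precision text-to-video sketch may be a useful reference. This is an editor's addition rather than part of the committed doc, and it assumes the standard `DiffusionPipeline` API (`enable_model_cpu_offload`, `.frames`) applies to `EasyAnimatePipeline` as it does to other diffusers video pipelines:

```py
import torch
from diffusers import EasyAnimatePipeline
from diffusers.utils import export_to_video

# Load the full pipeline in fp16; CPU offloading keeps peak VRAM manageable for the 12B model.
pipeline = EasyAnimatePipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", torch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()

prompt = "A cat walks on the grass, realistic style."
video = pipeline(prompt=prompt, num_frames=49, num_inference_steps=30).frames[0]

# The V5.1 models target roughly 8 FPS output (see the notes above).
export_to_video(video, "cat_t2v.mp4", fps=8)
```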

docs/source/en/using-diffusers/callback.md

Lines changed: 78 additions & 0 deletions
@@ -157,6 +157,84 @@ pipeline(
 )
 ```
 
+## IP Adapter Cutoff
+
+IP Adapter is an image prompt adapter that can be used with diffusion models without any changes to the underlying model. The IP Adapter Cutoff callback disables the IP Adapter after a certain number of denoising steps. To set up the callback, specify the step at which the cutoff takes effect using one of these two arguments:
+
+- `cutoff_step_ratio`: A float specifying the cutoff step as a fraction of the total number of inference steps.
+- `cutoff_step_index`: An integer specifying the exact step index at which the cutoff occurs.
+
+First, load the diffusion model and its IP Adapter as follows:
+
+```py
+from diffusers import AutoPipelineForText2Image
+from diffusers.utils import load_image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
+pipeline.set_ip_adapter_scale(0.6)
+```
+
+The setup for the callback should look something like this:
+
+```py
+from diffusers import AutoPipelineForText2Image
+from diffusers.callbacks import IPAdapterScaleCutoffCallback
+from diffusers.utils import load_image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16
+).to("cuda")
+
+pipeline.load_ip_adapter(
+    "h94/IP-Adapter",
+    subfolder="sdxl_models",
+    weight_name="ip-adapter_sdxl.bin"
+)
+
+pipeline.set_ip_adapter_scale(0.6)
+
+callback = IPAdapterScaleCutoffCallback(
+    cutoff_step_ratio=None,
+    cutoff_step_index=5
+)
+
+image = load_image(
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png"
+)
+
+generator = torch.Generator(device="cuda").manual_seed(2628670641)
+
+images = pipeline(
+    prompt="a tiger sitting in a chair drinking orange juice",
+    ip_adapter_image=image,
+    negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
+    generator=generator,
+    num_inference_steps=50,
+    callback_on_step_end=callback,
+).images
+
+images[0].save("custom_callback_img.png")
+```
+
+<div class="flex gap-4">
+  <div>
+    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/without_callback.png" alt="generated image of a tiger sitting in a chair drinking orange juice" />
+    <figcaption class="mt-2 text-center text-sm text-gray-500">without IPAdapterScaleCutoffCallback</figcaption>
+  </div>
+  <div>
+    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/with_callback2.png" alt="generated image of a tiger sitting in a chair drinking orange juice with ip adapter callback" />
+    <figcaption class="mt-2 text-center text-sm text-gray-500">with IPAdapterScaleCutoffCallback</figcaption>
+  </div>
+</div>
+
 ## Display image after each generation step
 
 > [!TIP]
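
The committed example selects the cutoff with `cutoff_step_index=5`; for reference, the ratio-based form is a one-line swap. This is an editor's sketch: the callback derives the cutoff step from the total number of inference steps times the ratio, so with 50 steps a ratio of 0.1 should land at roughly the same point.

```py
from diffusers.callbacks import IPAdapterScaleCutoffCallback

# Same intent as cutoff_step_index=5 with num_inference_steps=50, expressed as a
# fraction of the denoising schedule instead of an absolute step index.
callback = IPAdapterScaleCutoffCallback(cutoff_step_ratio=0.1)
```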

src/diffusers/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -94,6 +94,7 @@
     "AutoencoderKLCogVideoX",
     "AutoencoderKLHunyuanVideo",
     "AutoencoderKLLTXVideo",
+    "AutoencoderKLMagvit",
     "AutoencoderKLMochi",
     "AutoencoderKLTemporalDecoder",
     "AutoencoderKLWan",
@@ -109,6 +110,7 @@
     "ControlNetUnionModel",
     "ControlNetXSAdapter",
     "DiTTransformer2DModel",
+    "EasyAnimateTransformer3DModel",
     "FluxControlNetModel",
     "FluxMultiControlNetModel",
     "FluxTransformer2DModel",
@@ -293,6 +295,9 @@
     "CogView4Pipeline",
     "ConsisIDPipeline",
     "CycleDiffusionPipeline",
+    "EasyAnimateControlPipeline",
+    "EasyAnimateInpaintPipeline",
+    "EasyAnimatePipeline",
     "FluxControlImg2ImgPipeline",
     "FluxControlInpaintPipeline",
     "FluxControlNetImg2ImgPipeline",
@@ -620,6 +625,7 @@
     AutoencoderKLCogVideoX,
     AutoencoderKLHunyuanVideo,
     AutoencoderKLLTXVideo,
+    AutoencoderKLMagvit,
     AutoencoderKLMochi,
     AutoencoderKLTemporalDecoder,
     AutoencoderKLWan,
@@ -635,6 +641,7 @@
     ControlNetUnionModel,
     ControlNetXSAdapter,
     DiTTransformer2DModel,
+    EasyAnimateTransformer3DModel,
     FluxControlNetModel,
     FluxMultiControlNetModel,
     FluxTransformer2DModel,
@@ -798,6 +805,9 @@
     CogView4Pipeline,
     ConsisIDPipeline,
     CycleDiffusionPipeline,
+    EasyAnimateControlPipeline,
+    EasyAnimateInpaintPipeline,
+    EasyAnimatePipeline,
     FluxControlImg2ImgPipeline,
     FluxControlInpaintPipeline,
     FluxControlNetImg2ImgPipeline,

src/diffusers/loaders/single_file_utils.py

Lines changed: 2 additions & 2 deletions
@@ -1448,8 +1448,8 @@ def convert_open_clip_checkpoint(
 
     if text_proj_key in checkpoint:
         text_proj_dim = int(checkpoint[text_proj_key].shape[0])
-    elif hasattr(text_model.config, "projection_dim"):
-        text_proj_dim = text_model.config.projection_dim
+    elif hasattr(text_model.config, "hidden_size"):
+        text_proj_dim = text_model.config.hidden_size
     else:
         text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM

src/diffusers/models/__init__.py

File mode changed: 100644 → 100755
Lines changed: 4 additions & 0 deletions
@@ -33,6 +33,7 @@
     _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
     _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
     _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
+    _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
     _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
     _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
@@ -72,6 +73,7 @@
     _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
     _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
+    _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
     _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
     _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
@@ -109,6 +111,7 @@
         AutoencoderKLCogVideoX,
         AutoencoderKLHunyuanVideo,
         AutoencoderKLLTXVideo,
+        AutoencoderKLMagvit,
         AutoencoderKLMochi,
         AutoencoderKLTemporalDecoder,
         AutoencoderKLWan,
@@ -144,6 +147,7 @@
         ConsisIDTransformer3DModel,
         DiTTransformer2DModel,
         DualTransformer2DModel,
+        EasyAnimateTransformer3DModel,
         FluxTransformer2DModel,
         HunyuanDiT2DModel,
         HunyuanVideoTransformer3DModel,

src/diffusers/models/attention_processor.py

File mode changed: 100644 → 100755
Lines changed: 4 additions & 1 deletion
@@ -274,7 +274,10 @@ def __init__(
             self.to_add_out = None
 
         if qk_norm is not None and added_kv_proj_dim is not None:
-            if qk_norm == "fp32_layer_norm":
+            if qk_norm == "layer_norm":
+                self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+                self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            elif qk_norm == "fp32_layer_norm":
                 self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
                 self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
             elif qk_norm == "rms_norm":

src/diffusers/models/autoencoders/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
 from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
+from .autoencoder_kl_magvit import AutoencoderKLMagvit
 from .autoencoder_kl_mochi import AutoencoderKLMochi
 from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
 from .autoencoder_kl_wan import AutoencoderKLWan
