
Commit acea8cd

J4BEZDN6 authored and github-actions[bot] committed
[From Single File] support from_single_file method for WanVACE3DTransformer (huggingface#11807)
* add `WanVACETransformer3DModel` in `SINGLE_FILE_LOADABLE_CLASSES`
* add rename keys for `VACE`
* fix typo. Sincere thanks to @nitinmukesh 🙇‍♂️
* support for `1.3B VACE` model. Sincere thanks to @nitinmukesh again 🙇‍♂️
* update
* update
* Apply style fixes

---------

Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 555d82a · commit acea8cd

File tree: 3 files changed (+87, −1 lines)

- src/diffusers/loaders/single_file_model.py
- src/diffusers/loaders/single_file_utils.py
- tests/quantization/gguf/test_gguf.py

src/diffusers/loaders/single_file_model.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -136,6 +136,10 @@
         "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "WanVACETransformer3DModel": {
+        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
     "AutoencoderKLWan": {
         "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
         "default_subfolder": "vae",
```

src/diffusers/loaders/single_file_utils.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -126,6 +126,7 @@
     ],
     "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
     "wan_vae": "decoder.middle.0.residual.0.gamma",
+    "wan_vace": "vace_blocks.0.after_proj.bias",
     "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
     "cosmos-1.0": [
         "net.x_embedder.proj.1.weight",
```
```diff
@@ -202,6 +203,8 @@
     "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
     "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
     "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
+    "wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
+    "wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
     "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
     "cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
     "cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
```
```diff
@@ -716,7 +719,13 @@ def infer_diffusers_model_type(checkpoint):
         else:
             target_key = "patch_embedding.weight"

-        if checkpoint[target_key].shape[0] == 1536:
+        if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint:
+            if checkpoint[target_key].shape[0] == 1536:
+                model_type = "wan-vace-1.3B"
+            elif checkpoint[target_key].shape[0] == 5120:
+                model_type = "wan-vace-14B"
+
+        elif checkpoint[target_key].shape[0] == 1536:
             model_type = "wan-t2v-1.3B"
         elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
             model_type = "wan-t2v-14B"
```
```diff
@@ -3132,6 +3141,9 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
         "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
         "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
         "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+        # For the VACE model
+        "before_proj": "proj_in",
+        "after_proj": "proj_out",
     }

     for key in list(checkpoint.keys()):
```
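The dict is applied by substring replacement over every key in the checkpoint, so `vace_blocks.N.before_proj` and `vace_blocks.N.after_proj` become the diffusers names `vace_blocks.N.proj_in` and `vace_blocks.N.proj_out`. A simplified sketch of that loop (the real conversion helper does the same with a much larger mapping):

```python
def rename_keys(checkpoint: dict, rename_dict: dict) -> dict:
    # Rewrite each key by replacing every mapped substring, then move the
    # tensor to its new name.
    for key in list(checkpoint.keys()):
        new_key = key
        for old, new in rename_dict.items():
            new_key = new_key.replace(old, new)
        checkpoint[new_key] = checkpoint.pop(key)
    return checkpoint

ckpt = {"vace_blocks.0.after_proj.bias": 0}
print(rename_keys(ckpt, {"before_proj": "proj_in", "after_proj": "proj_out"}))
# {'vace_blocks.0.proj_out.bias': 0}
```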

tests/quantization/gguf/test_gguf.py

Lines changed: 70 additions & 0 deletions
```diff
@@ -15,6 +15,8 @@
     HiDreamImageTransformer2DModel,
     SD3Transformer2DModel,
     StableDiffusion3Pipeline,
+    WanTransformer3DModel,
+    WanVACETransformer3DModel,
 )
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
```
```diff
@@ -577,3 +579,71 @@ def get_dummy_inputs(self):
             ).to(torch_device, self.torch_dtype),
             "timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
         }
+
+
+class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+    ckpt_path = "https://huggingface.co/city96/Wan2.1-T2V-14B-gguf/blob/main/wan2.1-t2v-14b-Q3_K_S.gguf"
+    torch_dtype = torch.bfloat16
+    model_cls = WanTransformer3DModel
+    expected_memory_use_in_gb = 9
+
+    def get_dummy_inputs(self):
+        return {
+            "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+                torch_device, self.torch_dtype
+            ),
+            "encoder_hidden_states": torch.randn(
+                (1, 512, 4096),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+        }
+
+
+class WanGGUFImagetoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+    ckpt_path = "https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/blob/main/wan2.1-i2v-14b-480p-Q3_K_S.gguf"
+    torch_dtype = torch.bfloat16
+    model_cls = WanTransformer3DModel
+    expected_memory_use_in_gb = 9
+
+    def get_dummy_inputs(self):
+        return {
+            "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+                torch_device, self.torch_dtype
+            ),
+            "encoder_hidden_states": torch.randn(
+                (1, 512, 4096),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "encoder_hidden_states_image": torch.randn(
+                (1, 257, 1280), generator=torch.Generator("cpu").manual_seed(0)
+            ).to(torch_device, self.torch_dtype),
+            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+        }
+
+
+class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+    ckpt_path = "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
+    torch_dtype = torch.bfloat16
+    model_cls = WanVACETransformer3DModel
+    expected_memory_use_in_gb = 9
+
+    def get_dummy_inputs(self):
+        return {
+            "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+                torch_device, self.torch_dtype
+            ),
+            "encoder_hidden_states": torch.randn(
+                (1, 512, 4096),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "control_hidden_states": torch.randn(
+                (1, 96, 2, 64, 64),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "control_hidden_states_scale": torch.randn(
+                (8,),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+        }
```
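These tests load GGUF-quantized checkpoints through the same `from_single_file` path that users call. A minimal sketch, assuming the QuantStack checkpoint used in the test above and diffusers' existing `GGUFQuantizationConfig`:

```python
import torch
from diffusers import GGUFQuantizationConfig, WanVACETransformer3DModel

# GGUF checkpoint from the test above (Q3_K_S quantization).
ckpt_path = "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"

transformer = WanVACETransformer3DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
```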
