Skip to content

Commit 44ee074

Browse files
authored
feat: support ⚡️Z-Image-Turbo Nunchaku (#623)
* feat: support Z-Image Nunchaku * feat: support Z-Image Nunchaku * feat: support Z-Image Nunchaku * feat: support Z-Image Nunchaku * feat: support Z-Image Nunchaku * feat: support Z-Image Nunchaku
1 parent 8426e94 commit 44ee074

File tree

6 files changed

+203
-27
lines changed

6 files changed

+203
-27
lines changed

README.md

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,13 @@ You can install the stable release of cache-dit from PyPI, or the latest develop
7676
## 🔥Supported DiTs
7777

7878
> [!Tip]
79-
> One Model Series may contain many pipelines. cache-dit applies optimizations at the Transformer level; thus, any pipelines that include the supported transformer are already supported by cache-dit. ✅: supported now; ✖️: not supported now; **[`Q`](https://github.com/nunchaku-tech/nunchaku)**: [nunchaku](https://github.com/nunchaku-tech/nunchaku); **[C-P](./)**: Context Parallelism; **[T-P](./)**: Tensor Parallelism; **[TE-P](./)**: Text Encoder Parallelism; **[CN-P](./)**: ControlNet Parallelism; **[VAE-P](./)**: VAE Parallelism (TODO).
79+
> One Model Series may contain many pipelines. cache-dit applies optimizations at the Transformer level; thus, any pipelines that include the supported transformer are already supported by cache-dit. ✅: supported now; ✖️: not supported now; **[C-P](./)**: Context Parallelism; **[T-P](./)**: Tensor Parallelism; **[TE-P](./)**: Text Encoder Parallelism; **[CN-P](./)**: ControlNet Parallelism; **[VAE-P](./)**: VAE Parallelism (TODO).
8080
8181
<div align="center">
8282

8383
| 📚Supported DiTs: `🤗65+` | Cache | C-P | T-P | TE-P | CN-P | VAE-P |
8484
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
85+
| Z-Image-Turbo `⚡️Nunchaku` ||| ✖️ || ✖️ | ✖️ |
8586
| Qwen-Image-Layered ||||| ✖️ | ✖️ |
8687
| Qwen-Image-Edit-2511-Lightning ||||| ✖️ | ✖️ |
8788
| Qwen-Image-Edit-2511 ||||| ✖️ | ✖️ |
@@ -113,14 +114,14 @@ You can install the stable release of cache-dit from PyPI, or the latest develop
113114
| HunyuanImage-2.1 ||||| ✖️ | ✖️ |
114115
| HunyuanVideo-1.5 || ✖️ | ✖️ || ✖️ | ✖️ |
115116
| HunyuanVideo ||||| ✖️ | ✖️ |
116-
| FLUX.1-dev `Q` ||| ✖️ || ✖️ | ✖️ |
117-
| FLUX.1-Fill-dev `Q` ||| ✖️ || ✖️ | ✖️ |
118-
| Qwen-Image `Q` ||| ✖️ || ✖️ | ✖️ |
119-
| Qwen-Image-Edit `Q` ||| ✖️ || ✖️ | ✖️ |
120-
| Qwen-Image-Edit-2509 `Q` ||| ✖️ || ✖️ | ✖️ |
121-
| Qwen-Image-Lightning `Q` ||| ✖️ || ✖️ | ✖️ |
122-
| Qwen-Image-Edit-Lightning `Q` ||| ✖️ || ✖️ | ✖️ |
123-
| Qwen-Image-Edit-2509-Lightning `Q` ||| ✖️ || ✖️ | ✖️ |
117+
| FLUX.1-dev `⚡️Nunchaku` ||| ✖️ || ✖️ | ✖️ |
118+
| FLUX.1-Fill-dev `⚡️Nunchaku` ||| ✖️ || ✖️ | ✖️ |
119+
| Qwen-Image `⚡️Nunchaku` ||| ✖️ || ✖️ | ✖️ |
120+
| Qwen-Image-Edit `⚡️Nunchaku` ||| ✖️ || ✖️ | ✖️ |
121+
| Qwen-Image-Edit-2509 `⚡️Nunchaku` ||| ✖️ || ✖️ | ✖️ |
122+
| Qwen-Image-Lightning `⚡️Nunchaku` ||| ✖️ || ✖️ | ✖️ |
123+
| Qwen...Edit-Lightning `⚡️Nunchaku` ||| ✖️ || ✖️ | ✖️ |
124+
| Qwen...Edit-2509-Lightning `⚡️Nunchaku` ||| ✖️ || ✖️ | ✖️ |
124125
| SkyReels-V2-T2V ||||| ✖️ | ✖️ |
125126
| LongCat-Video || ✖️ | ✖️ || ✖️ | ✖️ |
126127
| ChronoEdit-14B ||||| ✖️ | ✖️ |

examples/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ python3 generate.py list # list all available examples
5858
[generate.py:53] - ✅ wan2.2_vace - Default: linoyts/Wan2.2-VACE-Fun-14B-diffusers
5959
[generate.py:53] - ✅ wan2.1_vace - Default: Wan-AI/Wan2.1-VACE-1.3B-diffusers
6060
[generate.py:53] - ✅ ovis_image - Default: AIDC-AI/Ovis-Image-7B
61+
[generate.py:53] - ✅ zimage_nunchaku - Default: nunchaku/nunchaku-z-image-turbo
6162
[generate.py:53] - ✅ zimage - Default: Tongyi-MAI/Z-Image-Turbo
6263
[generate.py:53] - ✅ zimage_controlnet_2.0 - Default: alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0
6364
[generate.py:53] - ✅ zimage_controlnet_2.1 - Default: alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.1
@@ -81,6 +82,7 @@ python3 generate.py qwen_image
8182
python3 generate.py skyreels_v2
8283
python3 generate.py wan2.2
8384
python3 generate.py zimage
85+
python3 generate.py zimage_nunchaku
8486
python3 generate.py zimage_controlnet_2.1
8587
python3 generate.py generate longcat_image
8688
python3 generate.py generate longcat_image_edit

examples/registers.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
"WAN_VACE_DIR": "Wan-AI/Wan2.1-VACE-1.3B-diffusers",
6565
"WAN_2_2_VACE_DIR": "linoyts/Wan2.2-VACE-Fun-14B-diffusers",
6666
"ZIMAGE_DIR": "Tongyi-MAI/Z-Image-Turbo",
67+
"NUNCHAKU_ZIMAGE_DIR": "nunchaku-tech/nunchaku-z-image-turbo",
6768
"Z_IMAGE_CONTROLNET_2_1_DIR": "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.1",
6869
"Z_IMAGE_CONTROLNET_2_0_DIR": "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
6970
"LONGCAT_IMAGE_DIR": "meituan-longcat/LongCat-Image",
@@ -804,20 +805,36 @@ def _zimage_turbo_steps_mask(
804805

805806

806807
@ExampleRegister.register("zimage", default="Tongyi-MAI/Z-Image-Turbo")
808+
@ExampleRegister.register("zimage_nunchaku", default="nunchaku/nunchaku-z-image-turbo")
807809
def zimage_example(args: argparse.Namespace, **kwargs) -> Example:
808810
from diffusers import ZImagePipeline
809811

810812
if args.cache:
811813
# Only warmup 4 steps (total 9 steps) for distilled models
812814
args.max_warmup_steps = min(4, args.max_warmup_steps)
813815

816+
if "nunchaku" in args.example.lower():
817+
from nunchaku import NunchakuZImageTransformer2DModel
818+
819+
nunchaku_zimage_dir = _path(
820+
"nunchaku-tech/nunchaku-z-image-turbo",
821+
args=args,
822+
transformer=True,
823+
)
824+
transformer = NunchakuZImageTransformer2DModel.from_pretrained(
825+
f"{nunchaku_zimage_dir}/svdq-int4_r128-z-image-turbo.safetensors"
826+
)
827+
else:
828+
transformer = None
829+
814830
steps_computation_mask = _zimage_turbo_steps_mask(args)
815831
return Example(
816832
args=args,
817833
init_config=ExampleInitConfig(
818834
task_type=ExampleType.T2I, # Text to Image
819835
model_name_or_path=_path("Tongyi-MAI/Z-Image-Turbo"),
820836
pipeline_class=ZImagePipeline,
837+
transformer=transformer, # maybe use Nunchaku zimage transformer
821838
bnb_4bit_components=["text_encoder"],
822839
extra_optimize_kwargs={
823840
"steps_computation_mask": steps_computation_mask,

src/cache_dit/parallelism/transformers/context_parallelism/__init__.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,6 @@ def _maybe_patch_native_parallel_config(
101101
if not cls_name.startswith("Nunchaku"):
102102
return transformer
103103

104-
from diffusers import FluxTransformer2DModel, QwenImageTransformer2DModel
105-
106104
try:
107105
from nunchaku.models.transformers.transformer_flux_v2 import (
108106
NunchakuFluxTransformer2DModelV2,
@@ -114,42 +112,54 @@ def _maybe_patch_native_parallel_config(
114112
NunchakuQwenImageNaiveFA2Processor,
115113
NunchakuQwenImageTransformer2DModel,
116114
)
115+
from nunchaku.models.transformers.transformer_zimage import (
116+
NunchakuZImageTransformer2DModel,
117+
NunchakuZSingleStreamAttnProcessor,
118+
NunchakuZImageAttention,
119+
)
117120
except ImportError:
118121
raise ImportError(
119-
"NunchakuFluxTransformer2DModelV2 or NunchakuQwenImageTransformer2DModel "
120-
"requires the 'nunchaku' package. Please install nunchaku before using "
121-
"the context parallelism for nunchaku 4-bits models."
122+
"NunchakuZImageTransformer2DModel, NunchakuFluxTransformer2DModelV2 and "
123+
"NunchakuQwenImageTransformer2DModel requires the 'nunchaku' package. "
124+
"Please install nunchaku>=1.10 before using the context parallelism for "
125+
"nunchaku 4-bits models."
122126
)
127+
123128
assert isinstance(
124129
transformer,
125130
(
126131
NunchakuFluxTransformer2DModelV2,
127-
FluxTransformer2DModel,
128-
),
129-
) or isinstance(
130-
transformer,
131-
(
132132
NunchakuQwenImageTransformer2DModel,
133-
QwenImageTransformer2DModel,
133+
NunchakuZImageTransformer2DModel,
134134
),
135-
), (
136-
"transformer must be an instance of NunchakuFluxTransformer2DModelV2 "
137-
f"or NunchakuQwenImageTransformer2DModel, but got {type(transformer)}"
138135
)
139-
config = transformer._parallel_config
136+
config = getattr(transformer, "_parallel_config", None)
137+
if config is None:
138+
raise logger.warning(
139+
f"The transformer {cls_name} does not have _parallel_config attribute. "
140+
"Skipping patching native parallel config."
141+
)
140142

141143
attention_classes = (
142144
NunchakuFluxAttention,
143145
NunchakuFluxFA2Processor,
144146
NunchakuQwenAttention,
145147
NunchakuQwenImageNaiveFA2Processor,
148+
NunchakuZImageAttention,
149+
NunchakuZSingleStreamAttnProcessor,
146150
)
147151
for module in transformer.modules():
148152
if not isinstance(module, attention_classes):
149153
continue
150154
processor = getattr(module, "processor", None)
151155
if processor is None or not hasattr(processor, "_parallel_config"):
152156
continue
157+
if getattr(processor, "_parallel_config", None) is not None:
158+
logger.warning(
159+
f"The attention processor {processor.__class__.__name__} already has "
160+
"_parallel_config attribute set. Skipping patching native parallel config."
161+
)
162+
continue
153163
processor._parallel_config = config
154164

155165
return transformer

src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_nunchaku.py

Lines changed: 146 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,17 @@
1919
NunchakuQwenImageNaiveFA2Processor,
2020
NunchakuQwenImageTransformer2DModel,
2121
)
22+
from nunchaku.models.transformers.transformer_zimage import (
23+
NunchakuZImageTransformer2DModel,
24+
NunchakuZSingleStreamAttnProcessor,
25+
NunchakuZImageAttention,
26+
)
2227
except ImportError:
2328
raise ImportError(
24-
"NunchakuFluxTransformer2DModelV2 or NunchakuQwenImageTransformer2DModel "
25-
"requires the 'nunchaku' package. Please install nunchaku before using "
26-
"the context parallelism for nunchaku 4-bits models."
29+
"NunchakuZImageTransformer2DModel, NunchakuFluxTransformer2DModelV2 and "
30+
"NunchakuQwenImageTransformer2DModel requires the 'nunchaku' package. "
31+
"Please install nunchaku>=1.10 before using the context parallelism for "
32+
"nunchaku 4-bits models."
2733
)
2834

2935
try:
@@ -43,6 +49,7 @@
4349
ContextParallelismPlannerRegister,
4450
)
4551

52+
from cache_dit.parallelism.attention import _maybe_patch_find_submodule
4653
from cache_dit.logger import init_logger
4754

4855
logger = init_logger(__name__)
@@ -383,3 +390,139 @@ def __patch_NunchakuQwenImageNaiveFA2Processor__call__(
383390
txt_attn_output = attn.to_add_out(txt_attn_output)
384391

385392
return img_attn_output, txt_attn_output
393+
394+
395+
@ContextParallelismPlannerRegister.register("NunchakuZImageTransformer2DModel")
396+
class NunchakuZImageContextParallelismPlanner(ContextParallelismPlanner):
397+
def apply(
398+
self,
399+
transformer: Optional[torch.nn.Module | ModelMixin] = None,
400+
**kwargs,
401+
) -> ContextParallelModelPlan:
402+
403+
# NOTE: Diffusers native CP plan still not supported for ZImageTransformer2DModel
404+
self._cp_planner_preferred_native_diffusers = False
405+
406+
if transformer is not None and self._cp_planner_preferred_native_diffusers:
407+
assert isinstance(
408+
transformer, NunchakuZImageTransformer2DModel
409+
), "Transformer must be an instance of NunchakuZImageTransformer2DModel"
410+
if hasattr(transformer, "_cp_plan"):
411+
if transformer._cp_plan is not None:
412+
return transformer._cp_plan
413+
414+
# NOTE: This is only a temporary workaround for ZImage to make context parallelism
415+
# work compatible with DBCache FnB0. The better way is to make DBCache fully
416+
# compatible with diffusers native context parallelism, e.g., check the split/gather
417+
# hooks in each block/layer in the initialization of DBCache.
418+
# Issue: https://github.com/vipshop/cache-dit/issues/498
419+
_maybe_patch_find_submodule()
420+
if not hasattr(NunchakuZSingleStreamAttnProcessor, "_parallel_config"):
421+
NunchakuZSingleStreamAttnProcessor._parallel_config = None
422+
if not hasattr(NunchakuZSingleStreamAttnProcessor, "_attention_backend"):
423+
NunchakuZSingleStreamAttnProcessor._attention_backend = None
424+
if not hasattr(NunchakuZImageAttention, "_parallel_config"):
425+
NunchakuZImageAttention._parallel_config = None
426+
if not hasattr(NunchakuZImageAttention, "_attention_backend"):
427+
NunchakuZImageAttention._attention_backend = None
428+
429+
n_noise_refiner_layers = len(transformer.noise_refiner) # 2
430+
n_context_refiner_layers = len(transformer.context_refiner) # 2
431+
n_layers = len(transformer.layers) # 30
432+
# controlnet layer idx: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28]
433+
# num_controlnet_samples = len(transformer.layers) // 2 # 15
434+
has_controlnet = kwargs.get("has_controlnet", None)
435+
if not has_controlnet:
436+
# cp plan for ZImageTransformer2DModel if no controlnet
437+
_cp_plan = {
438+
# 0. Hooks for noise_refiner layers, 2
439+
"noise_refiner.0": {
440+
"x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
441+
},
442+
"noise_refiner.*": {
443+
"freqs_cis": ContextParallelInput(
444+
split_dim=1, expected_dims=3, split_output=False
445+
),
446+
},
447+
f"noise_refiner.{n_noise_refiner_layers - 1}": ContextParallelOutput(
448+
gather_dim=1, expected_dims=3
449+
),
450+
# 1. Hooks for context_refiner layers, 2
451+
"context_refiner.0": {
452+
"x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
453+
},
454+
"context_refiner.*": {
455+
"freqs_cis": ContextParallelInput(
456+
split_dim=1, expected_dims=3, split_output=False
457+
),
458+
},
459+
f"context_refiner.{n_context_refiner_layers - 1}": ContextParallelOutput(
460+
gather_dim=1, expected_dims=3
461+
),
462+
# 2. Hooks for main transformer layers, num_layers=30
463+
"layers.0": {
464+
"x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
465+
},
466+
"layers.*": {
467+
"freqs_cis": ContextParallelInput(
468+
split_dim=1, expected_dims=3, split_output=False
469+
),
470+
},
471+
# NEED: call _maybe_patch_find_submodule to support ModuleDict like 'all_final_layer'
472+
"all_final_layer": ContextParallelOutput(gather_dim=1, expected_dims=3),
473+
# NOTE: The 'all_final_layer' is a ModuleDict of several final layers,
474+
# each for a specific patch size combination, so we do not add hooks for it here.
475+
# So, we have to gather the output of the last transformer layer.
476+
# f"layers.{num_layers - 1}": ContextParallelOutput(gather_dim=1, expected_dims=3),
477+
}
478+
else:
479+
# Special cp plan for NunchakuZImageTransformer2DModel with ZImageControlNetModel
480+
logger.warning(
481+
"Using special context parallelism plan for NunchakuZImageTransformer2DModel "
482+
"due to the 'has_controlnet' flag is set to True."
483+
)
484+
_cp_plan = {
485+
# zimage controlnet shared the same refiner as zimage, so, we need to
486+
# add gather hooks for all layers in noise_refiner and context_refiner.
487+
# 0. Hooks for noise_refiner layers, 2
488+
# Insert gather hook after each layers due to the ops: (controlnet)
489+
# - x = x + noise_refiner_block_samples[layer_idx]
490+
"noise_refiner.*": {
491+
"x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
492+
"freqs_cis": ContextParallelInput(
493+
split_dim=1, expected_dims=3, split_output=False
494+
),
495+
},
496+
**{
497+
f"noise_refiner.{i}": ContextParallelOutput(gather_dim=1, expected_dims=3)
498+
for i in range(n_noise_refiner_layers)
499+
},
500+
# 1. Hooks for context_refiner layers, 2
501+
"context_refiner.0": {
502+
"x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
503+
},
504+
"context_refiner.*": {
505+
"freqs_cis": ContextParallelInput(
506+
split_dim=1, expected_dims=3, split_output=False
507+
),
508+
},
509+
f"context_refiner.{n_context_refiner_layers - 1}": ContextParallelOutput(
510+
gather_dim=1, expected_dims=3
511+
),
512+
# 2. Hooks for main transformer layers, num_layers=30
513+
# Insert gather hook after each layers due to the ops: (main transformer)
514+
# - unified + controlnet_block_samples[layer_idx]
515+
"layers.*": {
516+
"x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
517+
"freqs_cis": ContextParallelInput(
518+
split_dim=1, expected_dims=3, split_output=False
519+
),
520+
},
521+
**{
522+
f"layers.{i}": ContextParallelOutput(gather_dim=1, expected_dims=3)
523+
for i in range(n_layers)
524+
},
525+
# NEED: call _maybe_patch_find_submodule to support ModuleDict like 'all_final_layer'
526+
"all_final_layer": ContextParallelOutput(gather_dim=1, expected_dims=3),
527+
}
528+
return _cp_plan

src/cache_dit/parallelism/transformers/context_parallelism/cp_planners.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ def _activate_cp_planners():
9696
from .cp_plan_nunchaku import ( # noqa: F401
9797
NunchakuQwenImageContextParallelismPlanner,
9898
)
99+
from .cp_plan_nunchaku import ( # noqa: F401
100+
NunchakuZImageContextParallelismPlanner,
101+
)
99102

100103

101104
__all__ = ["_activate_cp_planners"]

0 commit comments

Comments
 (0)