|
19 | 19 | NunchakuQwenImageNaiveFA2Processor, |
20 | 20 | NunchakuQwenImageTransformer2DModel, |
21 | 21 | ) |
| 22 | + from nunchaku.models.transformers.transformer_zimage import ( |
| 23 | + NunchakuZImageTransformer2DModel, |
| 24 | + NunchakuZSingleStreamAttnProcessor, |
| 25 | + NunchakuZImageAttention, |
| 26 | + ) |
22 | 27 | except ImportError: |
23 | 28 | raise ImportError( |
24 | | - "NunchakuFluxTransformer2DModelV2 or NunchakuQwenImageTransformer2DModel " |
25 | | - "requires the 'nunchaku' package. Please install nunchaku before using " |
26 | | - "the context parallelism for nunchaku 4-bits models." |
| 29 | + "NunchakuZImageTransformer2DModel, NunchakuFluxTransformer2DModelV2 and " |
| 30 | + "NunchakuQwenImageTransformer2DModel requires the 'nunchaku' package. " |
| 31 | + "Please install nunchaku>=1.10 before using the context parallelism for " |
| 32 | + "nunchaku 4-bits models." |
27 | 33 | ) |
28 | 34 |
|
29 | 35 | try: |
|
43 | 49 | ContextParallelismPlannerRegister, |
44 | 50 | ) |
45 | 51 |
|
| 52 | +from cache_dit.parallelism.attention import _maybe_patch_find_submodule |
46 | 53 | from cache_dit.logger import init_logger |
47 | 54 |
|
48 | 55 | logger = init_logger(__name__) |
@@ -383,3 +390,139 @@ def __patch_NunchakuQwenImageNaiveFA2Processor__call__( |
383 | 390 | txt_attn_output = attn.to_add_out(txt_attn_output) |
384 | 391 |
|
385 | 392 | return img_attn_output, txt_attn_output |
| 393 | + |
| 394 | + |
| 395 | +@ContextParallelismPlannerRegister.register("NunchakuZImageTransformer2DModel") |
| 396 | +class NunchakuZImageContextParallelismPlanner(ContextParallelismPlanner): |
| 397 | + def apply( |
| 398 | + self, |
| 399 | + transformer: Optional[torch.nn.Module | ModelMixin] = None, |
| 400 | + **kwargs, |
| 401 | + ) -> ContextParallelModelPlan: |
| 402 | + |
| 403 | + # NOTE: Diffusers' native CP plan is still not supported for ZImageTransformer2DModel |
| 404 | + self._cp_planner_preferred_native_diffusers = False |
| 405 | + |
| 406 | + if transformer is not None and self._cp_planner_preferred_native_diffusers: |
| 407 | + assert isinstance( |
| 408 | + transformer, NunchakuZImageTransformer2DModel |
| 409 | + ), "Transformer must be an instance of NunchakuZImageTransformer2DModel" |
| 410 | + if hasattr(transformer, "_cp_plan"): |
| 411 | + if transformer._cp_plan is not None: |
| 412 | + return transformer._cp_plan |
| 413 | + |
| 414 | + # NOTE: This is only a temporary workaround to make ZImage context parallelism |
| 415 | + # compatible with DBCache FnB0. A better approach is to make DBCache fully |
| 416 | + # compatible with diffusers' native context parallelism, e.g., by checking the |
| 417 | + # split/gather hooks of each block/layer during DBCache initialization. |
| 418 | + # Issue: https://github.com/vipshop/cache-dit/issues/498 |
| 419 | + _maybe_patch_find_submodule() |
| 420 | + if not hasattr(NunchakuZSingleStreamAttnProcessor, "_parallel_config"): |
| 421 | + NunchakuZSingleStreamAttnProcessor._parallel_config = None |
| 422 | + if not hasattr(NunchakuZSingleStreamAttnProcessor, "_attention_backend"): |
| 423 | + NunchakuZSingleStreamAttnProcessor._attention_backend = None |
| 424 | + if not hasattr(NunchakuZImageAttention, "_parallel_config"): |
| 425 | + NunchakuZImageAttention._parallel_config = None |
| 426 | + if not hasattr(NunchakuZImageAttention, "_attention_backend"): |
| 427 | + NunchakuZImageAttention._attention_backend = None |
| 428 | + |
| 429 | + n_noise_refiner_layers = len(transformer.noise_refiner) # 2 |
| 430 | + n_context_refiner_layers = len(transformer.context_refiner) # 2 |
| 431 | + n_layers = len(transformer.layers) # 30 |
| 432 | + # controlnet layer idx: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28] |
| 433 | + # num_controlnet_samples = len(transformer.layers) // 2 # 15 |
| 434 | + has_controlnet = kwargs.get("has_controlnet", None) |
| 435 | + if not has_controlnet: |
| 436 | + # cp plan for ZImageTransformer2DModel if no controlnet |
| 437 | + _cp_plan = { |
| 438 | + # 0. Hooks for noise_refiner layers, 2 |
| 439 | + "noise_refiner.0": { |
| 440 | + "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), |
| 441 | + }, |
| 442 | + "noise_refiner.*": { |
| 443 | + "freqs_cis": ContextParallelInput( |
| 444 | + split_dim=1, expected_dims=3, split_output=False |
| 445 | + ), |
| 446 | + }, |
| 447 | + f"noise_refiner.{n_noise_refiner_layers - 1}": ContextParallelOutput( |
| 448 | + gather_dim=1, expected_dims=3 |
| 449 | + ), |
| 450 | + # 1. Hooks for context_refiner layers, 2 |
| 451 | + "context_refiner.0": { |
| 452 | + "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), |
| 453 | + }, |
| 454 | + "context_refiner.*": { |
| 455 | + "freqs_cis": ContextParallelInput( |
| 456 | + split_dim=1, expected_dims=3, split_output=False |
| 457 | + ), |
| 458 | + }, |
| 459 | + f"context_refiner.{n_context_refiner_layers - 1}": ContextParallelOutput( |
| 460 | + gather_dim=1, expected_dims=3 |
| 461 | + ), |
| 462 | + # 2. Hooks for main transformer layers, num_layers=30 |
| 463 | + "layers.0": { |
| 464 | + "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), |
| 465 | + }, |
| 466 | + "layers.*": { |
| 467 | + "freqs_cis": ContextParallelInput( |
| 468 | + split_dim=1, expected_dims=3, split_output=False |
| 469 | + ), |
| 470 | + }, |
| 471 | + # NEED: call _maybe_patch_find_submodule to support ModuleDict like 'all_final_layer' |
| 472 | + "all_final_layer": ContextParallelOutput(gather_dim=1, expected_dims=3), |
| 473 | + # NOTE: 'all_final_layer' is a ModuleDict of several final layers, one per |
| 474 | + # patch-size combination. With the patched find_submodule, the gather hook can be |
| 475 | + # attached to it directly, so gathering after the last transformer layer is not needed: |
| 476 | + # f"layers.{n_layers - 1}": ContextParallelOutput(gather_dim=1, expected_dims=3), |
| 477 | + } |
| 478 | + else: |
| 479 | + # Special cp plan for NunchakuZImageTransformer2DModel with ZImageControlNetModel |
| 480 | + logger.warning( |
| 481 | + "Using special context parallelism plan for NunchakuZImageTransformer2DModel " |
| 482 | + "because the 'has_controlnet' flag is set to True." |
| 483 | + ) |
| 484 | + _cp_plan = { |
| 485 | + # ZImage ControlNet shares the same refiners as ZImage, so we need to |
| 486 | + # add gather hooks for all layers in noise_refiner and context_refiner. |
| 487 | + # 0. Hooks for noise_refiner layers, 2 |
| 488 | + # Insert a gather hook after each layer because of the ControlNet op: |
| 489 | + # - x = x + noise_refiner_block_samples[layer_idx] |
| 490 | + "noise_refiner.*": { |
| 491 | + "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), |
| 492 | + "freqs_cis": ContextParallelInput( |
| 493 | + split_dim=1, expected_dims=3, split_output=False |
| 494 | + ), |
| 495 | + }, |
| 496 | + **{ |
| 497 | + f"noise_refiner.{i}": ContextParallelOutput(gather_dim=1, expected_dims=3) |
| 498 | + for i in range(n_noise_refiner_layers) |
| 499 | + }, |
| 500 | + # 1. Hooks for context_refiner layers, 2 |
| 501 | + "context_refiner.0": { |
| 502 | + "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), |
| 503 | + }, |
| 504 | + "context_refiner.*": { |
| 505 | + "freqs_cis": ContextParallelInput( |
| 506 | + split_dim=1, expected_dims=3, split_output=False |
| 507 | + ), |
| 508 | + }, |
| 509 | + f"context_refiner.{n_context_refiner_layers - 1}": ContextParallelOutput( |
| 510 | + gather_dim=1, expected_dims=3 |
| 511 | + ), |
| 512 | + # 2. Hooks for main transformer layers, num_layers=30 |
| 513 | + # Insert a gather hook after each layer because of the main-transformer op: |
| 514 | + # - unified + controlnet_block_samples[layer_idx] |
| 515 | + "layers.*": { |
| 516 | + "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), |
| 517 | + "freqs_cis": ContextParallelInput( |
| 518 | + split_dim=1, expected_dims=3, split_output=False |
| 519 | + ), |
| 520 | + }, |
| 521 | + **{ |
| 522 | + f"layers.{i}": ContextParallelOutput(gather_dim=1, expected_dims=3) |
| 523 | + for i in range(n_layers) |
| 524 | + }, |
| 525 | + # NEED: call _maybe_patch_find_submodule to support ModuleDict like 'all_final_layer' |
| 526 | + "all_final_layer": ContextParallelOutput(gather_dim=1, expected_dims=3), |
| 527 | + } |
| 528 | + return _cp_plan |
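
For intuition, here is a minimal single-process sketch of what the split/gather hooks in the plans above do to the sequence dimension. This is not the cache-dit/diffusers implementation (the real hooks use torch.distributed collectives installed by the context-parallelism machinery), and the CP degree and tensor shape below are illustrative assumptions.

import torch

# Illustrative only: emulate what ContextParallelInput(split_dim=1) and
# ContextParallelOutput(gather_dim=1) do to a [batch, seq_len, hidden] tensor.
world_size = 2                      # hypothetical CP degree
x = torch.randn(1, 4096, 3072)      # assumed shape, e.g. packed image tokens

# Split hook at 'layers.0' / 'noise_refiner.0': each rank keeps one sequence shard.
shards = list(torch.chunk(x, world_size, dim=1))

# ... each rank runs the transformer blocks on its own shard ...

# Gather hook (ContextParallelOutput(gather_dim=1)), e.g. after every
# 'noise_refiner.{i}' / 'layers.{i}' in the ControlNet plan, because residual adds
# like x = x + noise_refiner_block_samples[layer_idx] need the full, un-split sequence.
x_full = torch.cat(shards, dim=1)
assert x_full.shape == x.shape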