
Commit ef5cb71

chore: simplify parallelism dispatch (#616)
* chore: simplify parallelism dispatch
1 parent d958b77 commit ef5cb71

81 files changed: +431 −499 lines changed


src/cache_dit/caching/cache_adapters/cache_adapter.py

Lines changed: 1 addition & 1 deletion
@@ -646,7 +646,7 @@ def _release_pipeline_params(pipe):
         cls.release_hooks(pipe_or_adapter, remove_stats, remove_stats, remove_stats)
 
         # maybe release parallelism stats
-        from cache_dit.parallelism.parallel_interface import (
+        from cache_dit.parallelism import (
             remove_parallelism_stats,
         )
 
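For callers, the dispatch simplification amounts to importing the release helper from the package root instead of the internal parallel_interface module. A minimal caller-side sketch, assuming only what this diff shows (the wrapper name release_parallelism is hypothetical, not part of the commit):

# Hypothetical wrapper; it only illustrates the new package-level import path.
from cache_dit.parallelism import remove_parallelism_stats


def release_parallelism(transformer):
    # Drops the _is_parallelized / _parallelism_config attributes set by
    # enable_parallelism(); logs a warning and returns the module unchanged
    # if it was never parallelized.
    return remove_parallelism_stats(transformer)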
Lines changed: 154 additions & 4 deletions
@@ -1,4 +1,154 @@
-from cache_dit.parallelism.parallel_backend import ParallelismBackend
-from cache_dit.parallelism.parallel_config import ParallelismConfig
-from cache_dit.parallelism.parallel_interface import enable_parallelism
-from cache_dit.parallelism.parallel_interface import maybe_pad_prompt
+import torch
+from diffusers.models.modeling_utils import ModelMixin
+from .backend import ParallelismBackend
+from .config import ParallelismConfig
+from cache_dit.utils import maybe_empty_cache
+from cache_dit.logger import init_logger
+from cache_dit.envs import ENV
+
+
+logger = init_logger(__name__)
+
+
+def enable_parallelism(
+    transformer: torch.nn.Module | ModelMixin,
+    parallelism_config: ParallelismConfig,
+) -> torch.nn.Module:
+    assert isinstance(transformer, (torch.nn.Module, ModelMixin)), (
+        "transformer must be an instance of torch.nn.Module or ModelMixin, "
+        f"but got {type(transformer)}"
+    )
+    if getattr(transformer, "_is_parallelized", False):
+        logger.warning("The transformer is already parallelized. Skipping parallelism enabling.")
+        return transformer
+
+    # Parallelize Transformer: The check of parallelism backend is only for transformer
+    # here. Text Encoder and VAE does not have different parallelism backends now.
+    from .transformers import maybe_enable_parallelism_for_transformer
+
+    transformer = maybe_enable_parallelism_for_transformer(
+        transformer=transformer,
+        parallelism_config=parallelism_config,
+    )
+    # Set attention backend for both context parallelism and tensor parallelism if the
+    # transformer is from diffusers and supports setting attention backend.
+    _maybe_set_module_attention_backend(
+        module=transformer,
+        parallelism_config=parallelism_config,
+    )
+
+    # Check text encoder and VAE for extra parallel modules
+    extra_parallel_modules: list[torch.nn.Module] = []
+    if parallelism_config.parallel_kwargs is not None:
+        extra_parallel_modules = parallelism_config.parallel_kwargs.get(
+            "extra_parallel_modules", []
+        )
+
+    if extra_parallel_modules:
+        for module in extra_parallel_modules:
+            # Enable parallelism for text encoder
+            if _is_text_encoder(module) and not _is_parallelized(module):
+                from .text_encoders import (
+                    maybe_enable_parallelism_for_text_encoder,
+                )
+
+                maybe_enable_parallelism_for_text_encoder(
+                    text_encoder=module,
+                    parallelism_config=parallelism_config,
+                )
+            # Enable parallelism for ControlNet
+            elif _is_controlnet(module) and not _is_parallelized(module):
+                from .controlnets import (
+                    maybe_enable_parallelism_for_controlnet,
+                )
+
+                maybe_enable_parallelism_for_controlnet(
+                    controlnet=module,
+                    parallelism_config=parallelism_config,
+                )
+                _maybe_set_module_attention_backend(
+                    module=module,
+                    parallelism_config=parallelism_config,
+                )
+            # Enable parallelism for VAE
+            elif _is_vae(module) and not _is_parallelized(module):
+                logger.warning("Parallelism for VAE is not supported yet. Skipped!")
+
+    # NOTE: Workaround for potential memory peak issue after parallelism
+    # enabling, specially for tensor parallelism in native pytorch backend.
+    maybe_empty_cache()
+
+    return transformer
+
+
+def remove_parallelism_stats(
+    module: torch.nn.Module,
+) -> torch.nn.Module:
+    if not getattr(module, "_is_parallelized", False):
+        logger.warning("The transformer is not parallelized. Skipping removing parallelism.")
+        return module
+
+    if hasattr(module, "_is_parallelized"):
+        del module._is_parallelized  # type: ignore[attr-defined]
+    if hasattr(module, "_parallelism_config"):
+        del module._parallelism_config  # type: ignore[attr-defined]
+    return module
+
+
+# Some helper functions for parallelism enabling
+def _maybe_set_module_attention_backend(
+    module: torch.nn.Module | ModelMixin,
+    parallelism_config: ParallelismConfig,
+) -> None:
+    # Set attention backend for both context parallelism and tensor parallelism if the
+    # transformer is from diffusers and supports setting attention backend.
+    module_cls_name = module.__class__.__name__
+    if hasattr(module, "set_attention_backend") and isinstance(module, ModelMixin):
+        attention_backend = parallelism_config.parallel_kwargs.get("attention_backend", None)
+        # native, _native_cudnn, flash, etc.
+        if attention_backend is None:
+            # Default to native for context parallelism due to:
+            # - attn mask support (re-registered in cache-dit)
+            # - general compatibility with various models
+            # NOTE: We only set default attention backend for NATIVE_DIFFUSER backend here
+            # while using context parallelism. For other backends, we do not change the
+            # attention backend if it is None.
+            if parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER:
+                module.set_attention_backend("native")
+                logger.warning(
+                    "attention_backend is None, set default attention backend of "
+                    f"{module_cls_name} to native for context parallelism."
+                )
+        else:
+            # Ensure custom attention backends are registered in cache-dit.
+            if not ENV.CACHE_DIT_ENABLE_CUSTOM_ATTN_ALREADY_DISPATCH:
+                from .attention import (
+                    _maybe_register_custom_attn_backends,
+                )
+
+                _maybe_register_custom_attn_backends()
+
+            module.set_attention_backend(attention_backend)
+            logger.info(
+                "Found attention_backend from config, set attention backend of "
+                f"{module_cls_name} to: {attention_backend}."
+            )
+
+
+def _is_text_encoder(module: torch.nn.Module) -> bool:
+    _import_module = module.__class__.__module__
+    return _import_module.startswith("transformers")
+
+
+def _is_controlnet(module: torch.nn.Module) -> bool:
+    _import_module = module.__class__.__module__
+    return _import_module.startswith("diffusers.models.controlnet")
+
+
+def _is_vae(module: torch.nn.Module) -> bool:
+    _import_module = module.__class__.__module__
+    return _import_module.startswith("diffusers.models.autoencoder")
+
+
+def _is_parallelized(module: torch.nn.Module) -> bool:
+    return getattr(module, "_is_parallelized", False)
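Taken together, the new module provides a single entry point for parallelizing a transformer plus any extra modules (text encoders, ControlNets). The sketch below shows the intended call pattern, assuming the names above are importable from cache_dit.parallelism (as the cache_adapter diff indicates) and that ParallelismConfig accepts backend and parallel_kwargs keyword arguments; the constructor signature, the FLUX model choice, and an already-initialized distributed environment are assumptions, not shown in this commit.

import torch
from diffusers import FluxTransformer2DModel  # illustrative model choice, not from this commit

from cache_dit.parallelism import (
    ParallelismBackend,
    ParallelismConfig,
    enable_parallelism,
    remove_parallelism_stats,
)

# Assumed constructor arguments: only the `backend` and `parallel_kwargs`
# attributes are actually read by the dispatch code above.
config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_DIFFUSER,
    parallel_kwargs={
        # None falls back to the "native" attention backend for NATIVE_DIFFUSER.
        "attention_backend": None,
        # Extra modules (text encoders / ControlNets) to parallelize alongside
        # the transformer; left empty here.
        "extra_parallel_modules": [],
    },
)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer = enable_parallelism(transformer, parallelism_config=config)

# ... run inference with the parallelized transformer ...

# Remove the bookkeeping attributes when parallelism is no longer needed.
transformer = remove_parallelism_stats(transformer)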

src/cache_dit/parallelism/vae/native_pytorch/__init__.py renamed to src/cache_dit/parallelism/autoencoders/data_parallelism/__init__.py

File renamed without changes.

src/cache_dit/parallelism/vae/native_pytorch/data_parallelism/dp_plan_autoencoder_kl.py renamed to src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl.py

File renamed without changes.

src/cache_dit/parallelism/vae/native_pytorch/data_parallelism/dp_plan_autoencoder_kl_flux2.py renamed to src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_flux2.py

File renamed without changes.

src/cache_dit/parallelism/vae/native_pytorch/data_parallelism/dp_plan_autoencoder_kl_qwen_image.py renamed to src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_qwen_image.py

File renamed without changes.

src/cache_dit/parallelism/vae/native_pytorch/data_parallelism/dp_plan_registers.py renamed to src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_registers.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 import logging
 from abc import abstractmethod
 from typing import Dict
-from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.parallelism.config import ParallelismConfig
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)

src/cache_dit/parallelism/vae/native_pytorch/data_parallelism/dp_planners.py renamed to src/cache_dit/parallelism/autoencoders/data_parallelism/dp_planners.py

File renamed without changes.
