Commit 088d909

apply review suggestions

1 parent 2aec312 commit 088d909

File tree

7 files changed, +112 -69 lines changed

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -189,6 +189,7 @@
         "CogView4Transformer2DModel",
         "ConsisIDTransformer3DModel",
         "ConsistencyDecoderVAE",
+        "ContextParallelConfig",
         "ControlNetModel",
         "ControlNetUnionModel",
         "ControlNetXSAdapter",
@@ -862,6 +863,7 @@
         CogView4Transformer2DModel,
         ConsisIDTransformer3DModel,
         ConsistencyDecoderVAE,
+        ContextParallelConfig,
         ControlNetModel,
         ControlNetUnionModel,
         ControlNetXSAdapter,
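Editor's note: the effect of these export additions is that the new config class becomes importable from the package root. A minimal sketch, assuming a diffusers build containing this commit:

import torch  # noqa: F401  (diffusers requires torch for this import path)
from diffusers import ContextParallelConfig

# ring_degree/ulysses_degree default to 1 via __post_init__ when left unset.
config = ContextParallelConfig(ring_degree=2)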

src/diffusers/hooks/context_parallel.py

Lines changed: 10 additions & 6 deletions
@@ -20,10 +20,10 @@
 import torch.distributed._functional_collectives as funcol

 from ..models._modeling_parallel import (
+    ContextParallelConfig,
     ContextParallelInput,
     ContextParallelModelPlan,
     ContextParallelOutput,
-    _InternalParallelConfig,
 )
 from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
@@ -74,11 +74,11 @@ def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None)

 def apply_context_parallel(
     module: torch.nn.Module,
-    parallel_config: _InternalParallelConfig,
+    parallel_config: ContextParallelConfig,
     plan: Dict[str, ContextParallelModelPlan],
 ) -> None:
     """Apply context parallel on a model."""
-    logger.debug(f"Applying context parallel with CP mesh: {parallel_config.cp_mesh} and plan: {plan}")
+    logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")

     for module_id, cp_model_plan in plan.items():
         submodule = _get_submodule_by_name(module, module_id)
@@ -122,7 +122,7 @@ def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextPara


 class ContextParallelSplitHook(ModelHook):
-    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: _InternalParallelConfig) -> None:
+    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
         super().__init__()
         self.metadata = metadata
         self.parallel_config = parallel_config
@@ -207,7 +207,7 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) ->


 class ContextParallelGatherHook(ModelHook):
-    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: _InternalParallelConfig) -> None:
+    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
         super().__init__()
         self.metadata = metadata
         self.parallel_config = parallel_config
@@ -251,7 +251,11 @@ def backward(ctx, grad_output):
 class EquipartitionSharder:
     @classmethod
     def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
-        assert tensor.size()[dim] % mesh.size() == 0
+        # NOTE: the following assertion does not have to be true in general. We simply enforce it for now
+        # because the alternate case has not yet been tested/required for any model.
+        assert tensor.size()[dim] % mesh.size() == 0, (
+            "Tensor size along dimension to be sharded must be divisible by mesh size"
+        )

         # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
         # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
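Editor's note: for intuition about the assertion added above, equipartition sharding hands each rank one equal-sized chunk along the split dimension, which is only well defined when that dimension divides evenly by the mesh size. A single-process sketch of the same slicing logic, with hypothetical sizes standing in for `mesh.size()` and `mesh.get_rank()`:

import torch

# A 4096-token sequence split across a hypothetical mesh of 4 ranks.
tensor = torch.randn(1, 4096, 64)  # (batch, seq_len, head_dim)
mesh_size, rank = 4, 1             # stand-ins for mesh.size() / mesh.get_rank()

assert tensor.size()[1] % mesh_size == 0  # the precondition the sharder enforces
local_shard = tensor.chunk(mesh_size, dim=1)[rank]
print(local_shard.shape)  # torch.Size([1, 1024, 64])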

src/diffusers/models/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@
 _import_structure = {}

 if is_torch_available():
-    _import_structure["_modeling_parallel"] = ["ParallelConfig", "enable_parallelism"]
+    _import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig", "enable_parallelism"]
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
     _import_structure["auto_model"] = ["AutoModel"]
@@ -120,7 +120,7 @@

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
-        from ._modeling_parallel import ParallelConfig, enable_parallelism
+        from ._modeling_parallel import ContextParallelConfig, ParallelConfig, enable_parallelism
         from .adapter import MultiAdapter, T2IAdapter
         from .attention_dispatch import AttentionBackendName, attention_backend
         from .auto_model import AutoModel

src/diffusers/models/_modeling_parallel.py

Lines changed: 56 additions & 25 deletions
@@ -40,52 +40,82 @@


 @dataclass
-class ParallelConfig:
+class ContextParallelConfig:
+    # Number of GPUs to use for ring attention within a context parallel region
     ring_degree: Optional[int] = None
+    # Number of GPUs to use for ulysses attention within a context parallel region
     ulysses_degree: Optional[int] = None
-
-    def __post_init__(self):
-        if self.ring_degree is None:
-            self.ring_degree = 1
-        if self.ulysses_degree is None:
-            self.ulysses_degree = 1
-
-
-@dataclass
-class _InternalParallelConfig:
-    rank: int
-    world_size: int
-    ring_degree: int
-    ulysses_degree: int
-    device: torch.device
-    cp_mesh: torch.distributed.device_mesh.DeviceMesh
-
     # Whether to convert output and LSE to float32 for ring attention numerical stability
     convert_to_fp32: bool = True
     # TODO: support alltoall
     rotate_method: Literal["allgather", "alltoall"] = "allgather"

+    _rank: int = None
+    _world_size: int = None
+    _device: torch.device = None
+    _mesh: torch.distributed.device_mesh.DeviceMesh = None
     _flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None
     _ring_mesh: torch.distributed.device_mesh.DeviceMesh = None
     _ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
     _ring_local_rank: int = None
     _ulysses_local_rank: int = None

     def __post_init__(self):
+        if self.ring_degree is None:
+            self.ring_degree = 1
+        if self.ulysses_degree is None:
+            self.ulysses_degree = 1
+
+    def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
+        self._rank = rank
+        self._world_size = world_size
+        self._device = device
+        self._mesh = mesh
+        if self.ring_degree is None:
+            self.ring_degree = 1
+        if self.ulysses_degree is None:
+            self.ulysses_degree = 1
         if self.rotate_method != "allgather":
-            raise ValueError(f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}.")
+            raise NotImplementedError(
+                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
+            )
         if self._flattened_mesh is None:
-            self._flattened_mesh = self.cp_mesh._flatten()
+            self._flattened_mesh = self._mesh._flatten()
         if self._ring_mesh is None:
-            self._ring_mesh = self.cp_mesh["ring"]
+            self._ring_mesh = self._mesh["ring"]
         if self._ulysses_mesh is None:
-            self._ulysses_mesh = self.cp_mesh["ulysses"]
+            self._ulysses_mesh = self._mesh["ulysses"]
         if self._ring_local_rank is None:
             self._ring_local_rank = self._ring_mesh.get_local_rank()
         if self._ulysses_local_rank is None:
             self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()


+@dataclass
+class ParallelConfig:
+    context_parallel_config: Optional[ContextParallelConfig] = None
+
+    _rank: int = None
+    _world_size: int = None
+    _device: torch.device = None
+    _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
+
+    def setup(
+        self,
+        rank: int,
+        world_size: int,
+        device: torch.device,
+        *,
+        cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+    ):
+        self._rank = rank
+        self._world_size = world_size
+        self._device = device
+        self._cp_mesh = cp_mesh
+        if self.context_parallel_config is not None:
+            self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
+
+
 @dataclass(frozen=True)
 class ContextParallelInput:
     split_dim: int
@@ -145,7 +175,7 @@ def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin
         parallelized_components = [
             (name, component)
             for name, component in model_or_pipeline.components.items()
-            if getattr(component, "_internal_parallel_config", None) is not None
+            if getattr(component, "_parallel_config", None) is not None
         ]
         if len(parallelized_components) > 1:
             raise ValueError(
@@ -158,7 +188,7 @@ def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin
             )
         _, model_or_pipeline = parallelized_components[0]
     elif isinstance(model_or_pipeline, ModelMixin):
-        if getattr(model_or_pipeline, "_internal_parallel_config", None) is None:
+        if getattr(model_or_pipeline, "_parallel_config", None) is None:
             raise ValueError(
                 "The model is not parallelized. Please ensure the model is parallelized with `.parallelize()` before using this context manager."
             )
@@ -167,8 +197,9 @@ def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin
             f"Expected a `DiffusionPipeline` or `ModelMixin` instance, but got {type(model_or_pipeline)}. Please provide a valid model or pipeline."
         )

+    # TODO: needs to be updated when more parallelism strategies are supported
     old_parallel_config = _AttentionBackendRegistry._parallel_config
-    _AttentionBackendRegistry._parallel_config = model_or_pipeline._internal_parallel_config
+    _AttentionBackendRegistry._parallel_config = model_or_pipeline._parallel_config.context_parallel_config

     yield
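Editor's note: taken together, this refactor replaces the flat `_InternalParallelConfig` with a public, nested config: `ParallelConfig` owns an optional `ContextParallelConfig`, and `setup()` threads rank/world-size/device/mesh state down into it. A minimal sketch of how the two dataclasses compose (plain dataclass construction; `setup()` is normally invoked for you by `ModelMixin.enable_parallelism`):

from diffusers.models import ContextParallelConfig, ParallelConfig

# Ring attention over 2 GPUs; ulysses_degree is left unset and defaults to 1 in __post_init__.
cp_config = ContextParallelConfig(ring_degree=2)
config = ParallelConfig(context_parallel_config=cp_config)

# enable_parallelism() later calls config.setup(rank, world_size, device, cp_mesh=...),
# which forwards the distributed state into the nested context-parallel config.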
src/diffusers/models/attention_dispatch.py

Lines changed: 3 additions & 3 deletions
@@ -40,7 +40,7 @@


 if TYPE_CHECKING:
-    from ._modeling_parallel import _InternalParallelConfig
+    from ._modeling_parallel import ContextParallelConfig

 _REQUIRED_FLASH_VERSION = "2.6.3"
 _REQUIRED_SAGE_VERSION = "2.1.1"
@@ -193,7 +193,7 @@ class _AttentionBackendRegistry:
     _supports_context_parallel = {}
     _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
     _checks_enabled = DIFFUSERS_ATTN_CHECKS
-    _parallel_config: Optional["_InternalParallelConfig"] = None
+    _parallel_config: Optional["ContextParallelConfig"] = None

     @classmethod
     def register(
@@ -729,7 +729,7 @@ def _flash_attention_forward_op(

     # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround.
     parallel_config = _AttentionBackendRegistry._parallel_config
-    if grad_enabled or (parallel_config is not None and parallel_config.world_size > 1):
+    if grad_enabled or (parallel_config is not None and parallel_config._world_size > 1):
         dropout_p = dropout_p if dropout_p > 0 else 1e-30

     with torch.set_grad_enabled(grad_enabled):
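Editor's note: the renamed `_world_size` attribute feeds the same LSE workaround as before; flash-attn only returns the log-sum-exp when dropout is active, and ring attention needs the LSE to combine partial attention results across ranks. A standalone sketch of the gate, with stand-in values for the registry state:

grad_enabled = False   # stand-in for the autograd flag at the call site
world_size = 2         # stand-in for parallel_config._world_size after this change
dropout_p = 0.0

# Force flash-attn to return the LSE by using a negligible, effectively-zero dropout.
if grad_enabled or world_size > 1:
    dropout_p = dropout_p if dropout_p > 0 else 1e-30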

src/diffusers/models/modeling_utils.py

Lines changed: 38 additions & 32 deletions
@@ -65,7 +65,7 @@
     populate_model_card,
 )
 from ..utils.torch_utils import empty_device_cache
-from ._modeling_parallel import ContextParallelModelPlan, ParallelConfig, _InternalParallelConfig
+from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
 from .model_loading_utils import (
     _caching_allocator_warmup,
     _determine_device_map,
@@ -249,7 +249,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _skip_layerwise_casting_patterns = None
     _supports_group_offloading = True
     _repeated_blocks = []
-    _internal_parallel_config = None
+    _parallel_config = None
     _cp_plan = None

     def __init__(self):
@@ -1481,55 +1481,61 @@ def compile_repeated_blocks(self, *args, **kwargs):
             f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
         )

-    def enable_parallelism(self, *, config: ParallelConfig, cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None):
+    def enable_parallelism(
+        self,
+        *,
+        config: Union[ParallelConfig, ContextParallelConfig],
+        cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
+    ):
         from ..hooks.context_parallel import apply_context_parallel

         logger.warning(
-            "`parallelize` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
+            "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
         )

+        if isinstance(config, ContextParallelConfig):
+            config = ParallelConfig(context_parallel_config=config)
+
         if not torch.distributed.is_initialized():
-            raise RuntimeError("torch.distributed must be initialized before calling `parallelize`.")
-        if config.ring_degree < 1 or config.ulysses_degree < 1:
-            raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
-        if config.ring_degree > 1 and config.ulysses_degree > 1:
-            raise ValueError(
-                "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
-            )
+            raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")

         rank = torch.distributed.get_rank()
         world_size = torch.distributed.get_world_size()
-
-        if config.ring_degree * config.ulysses_degree > world_size:
-            raise ValueError(
-                f"The product of `ring_degree` ({config.ring_degree}) and `ulysses_degree` ({config.ulysses_degree}) must not exceed the world size ({world_size})."
-            )
-
         device_type = torch._C._get_accelerator().type
         device_module = torch.get_device_module(device_type)
         device = torch.device(device_type, rank % device_module.device_count())

-        cp_mesh = torch.distributed.device_mesh.init_device_mesh(
-            device_type=device_type,
-            mesh_shape=(config.ring_degree, config.ulysses_degree),
-            mesh_dim_names=("ring", "ulysses"),
-        )
-        parallel_config = _InternalParallelConfig(
-            rank=rank,
-            world_size=world_size,
-            ring_degree=config.ring_degree,
-            ulysses_degree=config.ulysses_degree,
-            device=device,
-            cp_mesh=cp_mesh,
-        )
+        cp_mesh = None
+        if config.context_parallel_config is not None:
+            cp_config = config.context_parallel_config
+            if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
+                raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
+            if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
+                raise ValueError(
+                    "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
+                )
+            if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
+                raise ValueError(
+                    f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
+                )
+            cp_mesh = torch.distributed.device_mesh.init_device_mesh(
+                device_type=device_type,
+                mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
+                mesh_dim_names=("ring", "ulysses"),
+            )
+
+        config.setup(rank, world_size, device, cp_mesh=cp_mesh)
+
         if cp_plan is None and self._cp_plan is None:
             raise ValueError(
                 "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
             )
         cp_plan = cp_plan if cp_plan is not None else self._cp_plan

-        apply_context_parallel(self, parallel_config, cp_plan)
-        self._internal_parallel_config = parallel_config
+        if config.context_parallel_config is not None:
+            apply_context_parallel(self, config.context_parallel_config, cp_plan)
+
+        self._parallel_config = config

     @classmethod
     def _load_pretrained_model(
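Editor's note: because of the `isinstance` shortcut added above, callers can now hand `enable_parallelism` a bare `ContextParallelConfig`. A minimal usage sketch under stated assumptions: launched with e.g. `torchrun --nproc-per-node 2` on an NCCL-capable setup, and using a hypothetical checkpoint (any `ModelMixin` that defines `_cp_plan`, or receives an explicit `cp_plan`, works the same way):

import torch.distributed as dist
from diffusers import AutoModel, ContextParallelConfig

# enable_parallelism() requires an initialized process group.
dist.init_process_group(backend="nccl")

# Hypothetical checkpoint path; substitute a real transformer that defines _cp_plan.
model = AutoModel.from_pretrained("some/checkpoint", subfolder="transformer")
model.enable_parallelism(config=ContextParallelConfig(ring_degree=2))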

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 1 addition & 1 deletion
@@ -1053,7 +1053,7 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


-class ParallelConfig(metaclass=DummyObject):
+class ContextParallelConfig(metaclass=DummyObject):
     _backends = ["torch"]

     def __init__(self, *args, **kwargs):