Commit b85c26c

add docstrings
1 parent e569785 commit b85c26c

File tree

1 file changed: +61 −3 lines changed

src/diffusers/models/_modeling_parallel.py

Lines changed: 61 additions & 3 deletions
@@ -41,11 +41,26 @@
 
 @dataclass
 class ContextParallelConfig:
-    # Number of GPUs to use for ring attention within a context parallel region
+    """
+    Configuration for context parallelism.
+
+    Args:
+        ring_degree (`int`, *optional*, defaults to `1`):
+            Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
+            total number of devices in the context parallel mesh.
+        ulysses_degree (`int`, *optional*, defaults to `1`):
+            Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
+            total number of devices in the context parallel mesh.
+        convert_to_fp32 (`bool`, *optional*, defaults to `True`):
+            Whether to convert output and LSE to float32 for ring attention numerical stability.
+        rotate_method (`str`, *optional*, defaults to `"allgather"`):
+            Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
+            is supported.
+
+    """
+
     ring_degree: Optional[int] = None
-    # Number of context parallel regions to use for ulysses attention within a context parallel region
     ulysses_degree: Optional[int] = None
-    # Whether to convert output and LSE to float32 for ring attention numerical stability
     convert_to_fp32: bool = True
     # TODO: support alltoall
     rotate_method: Literal["allgather", "alltoall"] = "allgather"
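
Note (not part of this commit): a minimal sketch of how the documented fields fit together, assuming the private module path shown above is importable; the degree values are illustrative and, per the docstring, must each divide the number of devices in the context parallel mesh.

    # Illustrative only; assumes src/diffusers/models/_modeling_parallel.py is importable at this path.
    from diffusers.models._modeling_parallel import ContextParallelConfig

    # Use ring attention across 2 devices; ulysses attention stays at a single device.
    cp_config = ContextParallelConfig(
        ring_degree=2,
        ulysses_degree=1,
        convert_to_fp32=True,       # keep ring-attention output/LSE accumulation in float32
        rotate_method="allgather",  # currently the only supported rotation method
    )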
@@ -93,6 +108,14 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di
 
 @dataclass
 class ParallelConfig:
+    """
+    Configuration for applying different parallelisms.
+
+    Args:
+        context_parallel_config (`ContextParallelConfig`, *optional*):
+            Configuration for context parallelism.
+    """
+
     context_parallel_config: Optional[ContextParallelConfig] = None
 
     _rank: int = None
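
Note (not part of this commit): continuing the sketch above, `ParallelConfig` only exposes the context parallel settings in this hunk, so wiring the two configs together is a one-liner (names are illustrative).

    from diffusers.models._modeling_parallel import ContextParallelConfig, ParallelConfig

    # Only context parallelism is configurable through ParallelConfig at this point.
    parallel_config = ParallelConfig(
        context_parallel_config=ContextParallelConfig(ring_degree=2),
    )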
@@ -118,6 +141,21 @@ def setup(
 
 @dataclass(frozen=True)
 class ContextParallelInput:
+    """
+    Configuration for splitting an input tensor across a context parallel region.
+
+    Args:
+        split_dim (`int`):
+            The dimension along which to split the tensor.
+        expected_dims (`int`, *optional*):
+            The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
+            tensor has the expected number of dimensions before splitting.
+        split_output (`bool`, *optional*, defaults to `False`):
+            Whether to split the output tensor of the layer along the given `split_dim` instead of the input tensor.
+            This is useful for layers whose outputs should be split after the layer does some preprocessing on its
+            inputs (e.g. RoPE).
+    """
+
     split_dim: int
     expected_dims: Optional[int] = None
     split_output: bool = False
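
Note (not part of this commit): a small sketch of what the `ContextParallelInput` fields express for a tensor assumed to have a [batch, seq_len, dim] layout; how these objects get attached to a model is outside this hunk.

    from diffusers.models._modeling_parallel import ContextParallelInput

    # Shard the (assumed) [batch, seq_len, dim] input along its sequence dimension before the
    # layer runs; expected_dims guards against splitting a tensor of unexpected rank.
    split_hidden_states = ContextParallelInput(split_dim=1, expected_dims=3)

    # For a layer like RoPE that should see the full-length input first, split its output instead.
    split_after_rope = ContextParallelInput(split_dim=1, expected_dims=3, split_output=True)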
@@ -128,6 +166,17 @@ def __repr__(self):
 
 @dataclass(frozen=True)
 class ContextParallelOutput:
+    """
+    Configuration for gathering an output tensor across a context parallel region.
+
+    Args:
+        gather_dim (`int`):
+            The dimension along which to gather the tensor.
+        expected_dims (`int`, *optional*):
+            The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
+            tensor has the expected number of dimensions before gathering.
+    """
+
     gather_dim: int
     expected_dims: Optional[int] = None

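Note (not part of this commit): the gathering counterpart, under the same assumed [batch, seq_len, dim] layout.

    from diffusers.models._modeling_parallel import ContextParallelOutput

    # Each rank holds a slice of the sequence dimension; gathering along dim 1 after the final
    # layer restores the full-length output, with an optional rank check via expected_dims.
    gather_hidden_states = ContextParallelOutput(gather_dim=1, expected_dims=3)
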
@@ -198,6 +247,15 @@ def __repr__(self):
 
 @contextlib.contextmanager
 def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin"]):
+    """
+    A context manager to set the parallelism context for models or pipelines that have been parallelized.
+
+    Args:
+        model_or_pipeline (`DiffusionPipeline` or `ModelMixin`):
+            The model or pipeline to set the parallelism context for. The model or pipeline must have been parallelized
+            with `.enable_parallelism(ParallelConfig(...), ...)` before using this context manager.
+    """
+
     from diffusers import DiffusionPipeline, ModelMixin
 
     from .attention_dispatch import _AttentionBackendRegistry
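
Note (not part of this commit): an end-to-end sketch of the usage the new docstring describes. The checkpoint name, the torchrun launch, and calling `.enable_parallelism(...)` on the pipeline's transformer are assumptions based on the docstring, not code shown in this diff.

    # Assumed launch: torchrun --nproc-per-node=2 run_cp.py
    import torch
    from diffusers import DiffusionPipeline
    from diffusers.models._modeling_parallel import (
        ContextParallelConfig,
        ParallelConfig,
        enable_parallelism,
    )

    # Hypothetical checkpoint; any pipeline whose transformer supports context parallelism would do.
    pipe = DiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # Per the docstring, the model must first be parallelized with .enable_parallelism(ParallelConfig(...), ...).
    pipe.transformer.enable_parallelism(
        ParallelConfig(context_parallel_config=ContextParallelConfig(ring_degree=2))
    )

    # The context manager then sets the parallelism context while the parallelized model runs.
    with enable_parallelism(pipe):
        image = pipe("a prompt", num_inference_steps=30).images[0]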
