 
 @dataclass
 class ContextParallelConfig:
-    # Number of GPUs to use for ring attention within a context parallel region
+    """
+    Configuration for context parallelism.
+
+    Args:
+        ring_degree (`int`, *optional*, defaults to `1`):
+            Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
+            total number of devices in the context parallel mesh.
+        ulysses_degree (`int`, *optional*, defaults to `1`):
+            Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
+            total number of devices in the context parallel mesh.
+        convert_to_fp32 (`bool`, *optional*, defaults to `True`):
+            Whether to convert output and LSE to float32 for ring attention numerical stability.
+        rotate_method (`str`, *optional*, defaults to `"allgather"`):
+            Method to use for rotating key/value states across devices in ring attention. Currently, only
+            `"allgather"` is supported.
+    """
+
     ring_degree: Optional[int] = None
-    # Number of context parallel regions to use for ulysses attention within a context parallel region
     ulysses_degree: Optional[int] = None
-    # Whether to convert output and LSE to float32 for ring attention numerical stability
     convert_to_fp32: bool = True
     # TODO: support alltoall
     rotate_method: Literal["allgather", "alltoall"] = "allgather"
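As a usage sketch (not part of the diff): constructing the config above, assuming the class is importable from the top-level `diffusers` namespace; the degree values are illustrative only.

from diffusers import ContextParallelConfig  # import path assumed

# An 8-device context parallel region split into ring attention of degree 2 and
# ulysses attention of degree 4 (field names come from the dataclass above).
cp_config = ContextParallelConfig(ring_degree=2, ulysses_degree=4, convert_to_fp32=True)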
@@ -93,6 +108,14 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di
 
 @dataclass
 class ParallelConfig:
+    """
+    Configuration for applying different parallelisms.
+
+    Args:
+        context_parallel_config (`ContextParallelConfig`, *optional*):
+            Configuration for context parallelism.
+    """
+
     context_parallel_config: Optional[ContextParallelConfig] = None
 
     _rank: int = None
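A minimal sketch of how the wrapper config might be assembled, assuming the top-level imports exist; the model-level call mentioned in the comment is the one the context manager docstring further below refers to, with a positional config argument assumed.

from diffusers import ContextParallelConfig, ParallelConfig  # import paths assumed

parallel_config = ParallelConfig(
    context_parallel_config=ContextParallelConfig(ring_degree=2),
)
# A model would then be parallelized with something along the lines of
# `transformer.enable_parallelism(parallel_config)` before inference (the exact
# method signature is not shown in this excerpt).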
@@ -118,6 +141,21 @@ def setup(
 
 @dataclass(frozen=True)
 class ContextParallelInput:
+    """
+    Configuration for splitting an input tensor across a context parallel region.
+
+    Args:
+        split_dim (`int`):
+            The dimension along which to split the tensor.
+        expected_dims (`int`, *optional*):
+            The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
+            tensor has the expected number of dimensions before splitting.
+        split_output (`bool`, *optional*, defaults to `False`):
+            Whether to split the output tensor of the layer along the given `split_dim` instead of the input tensor.
+            This is useful for layers whose outputs should only be split after the layer has done some preprocessing
+            on the inputs (e.g., RoPE).
+    """
+
     split_dim: int
     expected_dims: Optional[int] = None
     split_output: bool = False
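For illustration, a hypothetical split specification; the tensor shape and the `split_output` choice are assumptions, not taken from this diff.

# Split a (batch, seq_len, hidden_dim) tensor along the sequence dimension, but
# only split the layer's output so that the layer (e.g. RoPE) can preprocess the
# full, unsplit input first.
split_spec = ContextParallelInput(split_dim=1, expected_dims=3, split_output=True)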
@@ -128,6 +166,17 @@ def __repr__(self):
 
 @dataclass(frozen=True)
 class ContextParallelOutput:
+    """
+    Configuration for gathering an output tensor across a context parallel region.
+
+    Args:
+        gather_dim (`int`):
+            The dimension along which to gather the tensor.
+        expected_dims (`int`, *optional*):
+            The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
+            tensor has the expected number of dimensions before gathering.
+    """
+
     gather_dim: int
     expected_dims: Optional[int] = None
 
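Taken together, the two dataclasses describe a per-module split/gather plan. A hypothetical example follows; the `_cp_plan` attribute name and the module names ("rope", "proj_out") are assumptions for illustration, not taken from this diff.

# Split the output of a hypothetical "rope" module along the sequence dimension,
# and gather the output of a hypothetical "proj_out" module back at the end.
_cp_plan = {
    "rope": ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
    "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}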
@@ -198,6 +247,15 @@ def __repr__(self):
 
 @contextlib.contextmanager
 def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin"]):
+    """
+    A context manager to set the parallelism context for models or pipelines that have been parallelized.
+
+    Args:
+        model_or_pipeline (`DiffusionPipeline` or `ModelMixin`):
+            The model or pipeline to set the parallelism context for. The model or pipeline must have been
+            parallelized with `.enable_parallelism(ParallelConfig(...), ...)` before using this context manager.
+    """
+
     from diffusers import DiffusionPipeline, ModelMixin
 
     from .attention_dispatch import _AttentionBackendRegistry
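An end-to-end usage sketch: the checkpoint name is a placeholder, a transformer-based pipeline and a top-level import for the context manager are assumed, and the script is expected to run under a multi-process launcher such as `torchrun`.

import torch
from diffusers import ContextParallelConfig, DiffusionPipeline, ParallelConfig
from diffusers import enable_parallelism  # import location assumed

# (torch.distributed process group / per-rank device setup omitted for brevity.)
pipe = DiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.bfloat16).to("cuda")

# Parallelize the denoiser first, as the docstring above requires.
pipe.transformer.enable_parallelism(
    ParallelConfig(context_parallel_config=ContextParallelConfig(ring_degree=2))
)

with enable_parallelism(pipe):
    image = pipe(prompt="a photo of a cat").images[0]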