
Commit c88fc99

add explanation
1 parent f35483a

File tree

1 file changed: +38 -0 lines changed


src/diffusers/models/_modeling_parallel.py

Lines changed: 38 additions & 0 deletions
@@ -155,6 +155,44 @@ def __repr__(self):
 ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
 
 
+# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
+#
+# Each model should define a _cp_plan attribute that contains information on how to shard/gather
+# tensors at different stages of the forward:
+#
+# ```python
+# _cp_plan = {
+#     "": {
+#         "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+#         "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+#         "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+#     },
+#     "pos_embed": {
+#         0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+#         1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+#     },
+#     "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+# }
+# ```
+#
+# The dictionary maps module names to their respective CP plans. The inputs/outputs of the listed
+# modules are split/gathered accordingly, at the level of each module. Here, the following happens:
+# - "":
+#     we specify that we want to split the various inputs across the sequence dim in the pre-forward
+#     hook (i.e. before the actual forward logic of the QwenImageTransformer2DModel runs, we split
+#     the inputs)
+# - "pos_embed":
+#     we specify that we want to split the outputs of the RoPE layer. Since there are two outputs
+#     (image & text freqs), we can individually specify how each should be split
+# - "proj_out":
+#     before returning to the user, we gather the entire sequence on each rank in the post-forward
+#     hook (after the linear layer's forward has run)
+#
+# ContextParallelInput:
+#     specifies how to split the input tensor in the pre-forward or post-forward hook of the layer
+#     it is attached to
+#
+# ContextParallelOutput:
+#     specifies how to gather the output tensor in the post-forward hook of the layer it is attached to
 _ENABLE_PARALLELISM_WARN_ONCE = False
 
 
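To make the hook mechanics above concrete, here is a minimal sketch of what a pre-forward hook could do with a `ContextParallelInput` spec. The helper name `shard_input` and the explicit `rank`/`world_size` parameters are assumptions for illustration, not diffusers' actual hook internals, and the sketch assumes the split dimension is evenly divisible by the world size:

```python
import torch

# Illustrative helper (assumed name, not the library's API): keep only this
# rank's chunk of the tensor along the dimension named by the plan.
def shard_input(tensor: torch.Tensor, spec, rank: int, world_size: int) -> torch.Tensor:
    # Validate the tensor's dimensionality against the plan's expectation.
    assert tensor.dim() == spec.expected_dims
    # For "hidden_states" in the example plan, split_dim=1 is the sequence
    # dimension, so each rank processes 1/world_size of the sequence.
    return tensor.chunk(world_size, dim=spec.split_dim)[rank]
```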
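The matching post-forward gather for a `ContextParallelOutput` spec might look like the sketch below, assuming an initialized `torch.distributed` process group; `gather_output` is likewise an illustrative name:

```python
import torch
import torch.distributed as dist

# Illustrative helper (assumed name): all-gather the per-rank shards and
# concatenate them along the plan's gather dimension.
def gather_output(tensor: torch.Tensor, spec, world_size: int) -> torch.Tensor:
    assert tensor.dim() == spec.expected_dims
    # Collect every rank's shard of the layer output...
    shards = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(shards, tensor.contiguous())
    # ...and stitch the full sequence back together (gather_dim=1 for
    # "proj_out" in the example plan), so every rank returns the same
    # complete tensor to the caller.
    return torch.cat(shards, dim=spec.gather_dim)
```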