@@ -155,6 +155,44 @@ def __repr__(self):
 ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]


+# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
+#
+# Each model should define a _cp_plan attribute that contains information on how to shard/gather
+# tensors at different stages of the forward:
+#
+# ```python
+# _cp_plan = {
+#     "": {
+#         "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+#         "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+#         "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+#     },
+#     "pos_embed": {
+#         0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+#         1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+#     },
+#     "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+# }
+# ```
+#
+# The dictionary is a mapping of module names to their respective CP plan. The inputs/outputs of layers will be
+# split/gathered according to this at the respective module level. Here, the following happens (a wiring sketch
+# follows this list):
+# - "":
+#     we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e.
+#     before the actual forward logic of the QwenImageTransformer2DModel is run, we split the inputs)
+# - "pos_embed":
+#     we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (image & text
+#     freqs), we can individually specify how each should be split
+# - "proj_out":
+#     before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the
+#     linear layer forward has run).
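+#
+# Conceptually, each key in the plan resolves to a submodule and a hook is attached to it. The sketch below is a
+# simplified, hypothetical illustration of that wiring (make_split_hook / make_gather_hook are illustrative helpers,
+# not actual library functions; the real registration logic differs and handles more cases):
+#
+# ```python
+# for module_name, plan in model._cp_plan.items():
+#     module = model if module_name == "" else model.get_submodule(module_name)
+#     if isinstance(plan, ContextParallelOutput):
+#         # gather this module's output over the sequence dim after its forward has run
+#         module.register_forward_hook(make_gather_hook(plan))
+#     elif any(p.split_output for p in plan.values()):
+#         # split the indexed outputs after the forward (split_output=True), e.g. the RoPE frequencies
+#         module.register_forward_hook(make_split_hook(plan))
+#     else:
+#         # split the named inputs before the forward runs
+#         module.register_forward_pre_hook(make_split_hook(plan), with_kwargs=True)
+# ```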
+#
+# ContextParallelInput:
+#     specifies how to split the input tensor in the pre-forward hook (or the output tensor in the post-forward
+#     hook, when split_output=True) of the layer it is attached to
+#
+# ContextParallelOutput:
+#     specifies how to gather the output tensor in the post-forward hook of the layer it is attached to
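+#
+# For intuition, the split/gather operations themselves might look roughly like the following across a
+# context-parallel group of world_size ranks (a minimal conceptual sketch using plain torch.distributed
+# collectives, not the actual implementation):
+#
+# ```python
+# import torch
+# import torch.distributed as dist
+#
+# def split_for_rank(tensor: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
+#     # pre-forward: each rank keeps only its chunk of the given (e.g. sequence) dimension
+#     # (assumes the dimension divides evenly across ranks)
+#     return tensor.chunk(world_size, dim=dim)[rank]
+#
+# def gather_from_ranks(tensor: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
+#     # post-forward: reassemble the full tensor on every rank before returning to the user
+#     chunks = [torch.empty_like(tensor) for _ in range(world_size)]
+#     dist.all_gather(chunks, tensor.contiguous())
+#     return torch.cat(chunks, dim=dim)
+# ```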
+
 _ENABLE_PARALLELISM_WARN_ONCE = False

