Skip to content

Commit de32920

Browse files
authored
[Bugfix] Force all reduce in bwd for router at MoE boundary (#2416)
Solves #2387, where a user reported a bug that MoE router weights differ on each rank due to a missing all reduce under certain scenarios: - `self.score_before_experts=False` for both ETP=1 and ETP=TP - `self.score_before_experts=True` and ETP=1 A similar fix was proposed in #2388 by @fatih-uzlmz ## Bug analysis **Case 1:** When ETP=1 < TP, MoE is using TP2EP (`etp_mesh` is None, there is no ETP), `gate` is Replicated (via NoParallel), experts weight is Sharded (via EP mesh), MoE input is Replicated. _Root Cause:_ `ReordererSequenceParallel` is applied to split `top_scores` across TP ranks. Since each TP rank processes its local slice of tokens through the experts, `d_routed_input` is Partial and there is no forced all-reduce. **Case 2:** When ETP = TP, MoE is using DP2EP, `gate` is Replicated (via NoParallel), experts weight is Sharded (via ETP, which forces all reduce in bwd, #1878), MoE input is Replicated on TP mesh. _Root Cause:_ The gradient flowing back to `gate.weight` is Partial when `self.score_before_experts=False`, which we incorrectly marked as Replicated and therefore results in wrong numerics. _Explanation:_ When `self.score_before_experts=True`, `d_routed_input` is already all reduced because of the fix #1878 we added to ETP. ``` routed_input = routed_input * top_scores_experts_sorted routed_output = self.experts(routed_input, ...) out_experts = routed_output_unsorted.sum(dim=1) ``` When `self.score_before_experts=False`, the gradient of `top_scores` from `bmm` is Partial. The previous all reduce fix only forced all reduce on `d_routed_input`, but not on `d_top_scores`, which is wrong. ``` routed_output = self.experts(routed_input, ...) out_experts = bmm(top_scores, routed_output_unsorted) ``` ## The fix we propose that fixes both Case 1 and Case 2 - We revert the fix in #1878 to keep Partial dx flowing back, and delay the all reduce to the boundary of MoE. 
- We choose not to use [`ColwiseParallel/RowwiseParallel`](https://github.com/pytorch/pytorch/blob/e81980ea58c65a283cf01a7132e0f420834ecf10/torch/distributed/tensor/parallel/style.py#L181) for the shared expert, as they force all reduce on dx; but we want to keep Partial dx. So we introduced `MoEColWiseParallel/MoERowwiseParallel`, which remove the input and output hooks. See the [reasoning here](https://gist.github.com/acisseJZhong/b27ecd7eca9ca1f0c050282516de2e21) for more details. - We annotate the gradient of the router as Partial in bwd through adding a field to NoParallel(), so that DTensor will handle the all reduce <img width="4096" height="2056" alt="image" src="https://github.com/user-attachments/assets/8c65e360-434c-4138-9c58-f99041da5389" /> <img width="4096" height="2056" alt="image" src="https://github.com/user-attachments/assets/474f431f-a469-400c-9256-77b514a40d7f" /> **Caveat:** The graph is plotted from a [Claude-generated computational graph](https://gist.github.com/acisseJZhong/7101c2542530365b7b06bdf1f78068cb); please take it with a grain of salt. cc @volcacius @fatih-uzlmz
1 parent b4bbb2c commit de32920

File tree

9 files changed

+286
-151
lines changed

9 files changed

+286
-151
lines changed

torchtitan/distributed/__init__.py

Lines changed: 0 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -4,97 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
8-
from functools import partial
9-
from typing import Any
10-
11-
import torch
12-
import torch.nn as nn
13-
from torch.distributed.tensor import DeviceMesh, distribute_module, DTensor, Replicate
14-
from torch.distributed.tensor.parallel import ParallelStyle
15-
from torch.distributed.tensor.placement_types import Placement
16-
177
from torchtitan.distributed.parallel_dims import ParallelDims
188

199
__all__ = [
2010
"ParallelDims",
21-
"NoParallel",
2211
]
23-
24-
25-
# NOTE: This is to achieve replicate computation on the gate module in the MoE router.
26-
# It does nothing other than (1) setting the module parameters as DTensors on the given mesh
27-
# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back.
28-
# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
29-
# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
30-
class NoParallel(ParallelStyle):
31-
def __init__(
32-
self,
33-
*,
34-
input_layout: Placement | None = None,
35-
output_layout: Placement | None = None,
36-
use_local_output: bool = True,
37-
):
38-
super().__init__()
39-
self.input_layout = input_layout or Replicate()
40-
self.output_layout = output_layout or Replicate()
41-
self.desired_input_layout = Replicate()
42-
self.use_local_output = use_local_output
43-
44-
@staticmethod
45-
def _prepare_input_fn(
46-
input_layout: Placement | None,
47-
desired_input_layout: Placement | None,
48-
mod: nn.Module,
49-
inputs: Any,
50-
device_mesh: DeviceMesh,
51-
):
52-
# annotate module input placements/sharding with input_layouts
53-
input_tensor = inputs[0]
54-
if not isinstance(input_tensor, DTensor):
55-
assert input_layout is not None
56-
input_tensor = DTensor.from_local(
57-
input_tensor, device_mesh, (input_layout,), run_check=False
58-
)
59-
60-
if input_layout != desired_input_layout:
61-
assert input_layout is not None
62-
assert desired_input_layout is not None
63-
input_tensor = input_tensor.redistribute(
64-
placements=(desired_input_layout,), async_op=True
65-
)
66-
return (input_tensor, *inputs[1:])
67-
68-
@staticmethod
69-
def _prepare_output_fn(
70-
output_layout: Placement,
71-
use_local_output: bool,
72-
mod: nn.Module,
73-
outputs: DTensor,
74-
device_mesh: DeviceMesh,
75-
) -> torch.Tensor | DTensor:
76-
if outputs.placements != (output_layout,):
77-
outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
78-
# back to local tensor
79-
return outputs.to_local() if use_local_output else outputs
80-
81-
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
82-
return distribute_module(
83-
module,
84-
device_mesh,
85-
None,
86-
partial(
87-
# TODO: this is pytorch distribute_module typing issue.
88-
# pyrefly: ignore [bad-argument-type]
89-
self._prepare_input_fn,
90-
self.input_layout,
91-
self.desired_input_layout,
92-
),
93-
partial(
94-
# TODO: this is pytorch distribute_module typing issue.
95-
# pyrefly: ignore [bad-argument-type]
96-
self._prepare_output_fn,
97-
self.output_layout,
98-
self.use_local_output,
99-
),
100-
)

torchtitan/distributed/expert_parallel.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
distribute_module,
1919
distribute_tensor,
2020
DTensor,
21-
Partial,
22-
Replicate,
2321
Shard,
2422
)
2523
from torch.distributed.tensor.parallel import ParallelStyle
@@ -47,17 +45,6 @@ def _token_combine(
4745

4846
# implementation of Tensor Parallel for the GroupedExperts in MoE
4947
class TensorParallel(ParallelStyle):
50-
def _prepare_input_fn(self, mod, inputs, device_mesh):
51-
routed_input, num_tokens_per_expert = inputs
52-
# NOTE: Currently in MoE TP, experts multiplication runs in plain Tensors.
53-
# The grad_placements on inputs is set to Partial so that necessary
54-
# reductions are performed during backward.
55-
routed_input = DTensor.from_local(
56-
routed_input, device_mesh, (Replicate(),)
57-
).to_local(grad_placements=(Partial(),))
58-
59-
return routed_input, num_tokens_per_expert
60-
6148
def _partition_fn(self, name, module, device_mesh):
6249
# w1 shape = (experts, out_dim, in_dim)
6350
module.register_parameter(
@@ -81,8 +68,6 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
8168
module,
8269
device_mesh,
8370
self._partition_fn,
84-
# pyrefly: ignore [bad-argument-type]
85-
self._prepare_input_fn,
8671
)
8772

8873

@@ -195,23 +180,6 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
195180
# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
196181
class ExpertTensorParallel(ExpertParallel):
197182
def _token_dispatch(self, mod, inputs, device_mesh):
198-
routed_input, num_tokens_per_expert = inputs
199-
200-
# NOTE: Currently in MoE TP, experts multiplication runs in plain Tensors.
201-
# The grad_placements on inputs is set to Partial so that necessary
202-
# reductions are performed during backward.
203-
204-
# NOTE: The mesh used here should be dense_mesh["tp"] as routed_input is
205-
# technically wrapped with the dense_mesh["tp"] but this complicates
206-
# the interface of ExpertTensorParallel and it doesn't matter as etp
207-
# is almost always the same as tp or is 1. To avoid the complexity,
208-
# we use the etp mesh here.
209-
routed_input = DTensor.from_local(
210-
routed_input, device_mesh["etp"], (Replicate(),)
211-
).to_local(grad_placements=(Partial(),))
212-
213-
inputs = (routed_input, num_tokens_per_expert)
214-
215183
# token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
216184
return super()._token_dispatch(mod, inputs, device_mesh["ep"])
217185

torchtitan/distributed/tensor_parallel.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,189 @@
55
# LICENSE file in the root directory of this source tree.
66

77

8+
from collections.abc import Sequence
9+
from functools import partial
10+
from typing import Any
11+
812
import torch
913
import torch._inductor.config
14+
import torch.nn as nn
1015
from torch.distributed.device_mesh import DeviceMesh
16+
from torch.distributed.tensor import distribute_module, DTensor, Replicate
17+
from torch.distributed.tensor.parallel import ColwiseParallel, ParallelStyle
18+
from torch.distributed.tensor.placement_types import Placement
1119

1220
from torchtitan.config import CompileConfig, ParallelismConfig
1321
from torchtitan.tools.logging import logger
1422

1523

24+
class NoParallel(ParallelStyle):
25+
"""Replicate computation on the TP mesh without sharding.
26+
27+
This style does nothing other than:
28+
(1) setting the module parameters as DTensors on the given mesh, and
29+
(2) inserting hooks at module boundary to convert torch.Tensor to DTensor and back.
30+
31+
The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
32+
which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
33+
34+
Used for modules like the MoE router gate that need replicated computation on TP mesh.
35+
"""
36+
37+
def __init__(
38+
self,
39+
*,
40+
input_layout: Placement | None = None,
41+
output_layout: Placement | None = None,
42+
local_output_grad_placements: Sequence[Placement] | None = None,
43+
):
44+
super().__init__()
45+
self.input_layout = input_layout or Replicate()
46+
self.output_layout = output_layout or Replicate()
47+
self.desired_input_layout = Replicate()
48+
# If None, output stays as DTensor.
49+
# If provided, output is cast to local tensor via
50+
# to_local(grad_placements=local_output_grad_placements).
51+
self.local_output_grad_placements = local_output_grad_placements
52+
53+
@staticmethod
54+
def _prepare_input_fn(
55+
input_layout: Placement | None,
56+
desired_input_layout: Placement | None,
57+
mod: nn.Module,
58+
inputs: Any,
59+
device_mesh: DeviceMesh,
60+
):
61+
input_tensor = inputs[0]
62+
if not isinstance(input_tensor, DTensor):
63+
assert input_layout is not None
64+
input_tensor = DTensor.from_local(
65+
input_tensor, device_mesh, (input_layout,), run_check=False
66+
)
67+
68+
if input_layout != desired_input_layout:
69+
assert input_layout is not None
70+
assert desired_input_layout is not None
71+
input_tensor = input_tensor.redistribute(
72+
placements=(desired_input_layout,), async_op=True
73+
)
74+
return (input_tensor, *inputs[1:])
75+
76+
@staticmethod
77+
def _prepare_output_fn(
78+
output_layout: Placement,
79+
local_output_grad_placements: Sequence[Placement] | None,
80+
mod: nn.Module,
81+
outputs: DTensor,
82+
device_mesh: DeviceMesh,
83+
) -> torch.Tensor | DTensor:
84+
if outputs.placements != (output_layout,):
85+
outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
86+
if local_output_grad_placements is not None:
87+
return outputs.to_local(grad_placements=local_output_grad_placements)
88+
else:
89+
return outputs
90+
91+
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
92+
return distribute_module(
93+
module,
94+
device_mesh,
95+
None,
96+
partial(
97+
self._prepare_input_fn, # pyrefly: ignore [bad-argument-type]
98+
self.input_layout,
99+
self.desired_input_layout,
100+
),
101+
partial(
102+
self._prepare_output_fn, # pyrefly: ignore [bad-argument-type]
103+
self.output_layout,
104+
self.local_output_grad_placements,
105+
),
106+
)
107+
108+
109+
class ColwiseParallelWithGradPlacement(ColwiseParallel):
110+
"""ColwiseParallel with explicit control over backward gradient placement.
111+
112+
By default, ``ColwiseParallel`` with ``input_layouts=Replicate()`` wraps
113+
the input via ``from_local(Replicate)``, whose backward all-reduces d_x
114+
back to Replicate. This subclass overrides ``_prepare_input_fn`` to pass
115+
``local_input_grad_placements`` to ``DTensor.from_local``, giving users
116+
explicit control over the gradient placement during backward. When not
117+
specified, defaults to ``None`` and the gradient placement follows the
118+
default guarantees of ``DTensor.from_local``.
119+
"""
120+
121+
def __init__(
122+
self,
123+
*,
124+
input_layouts: Placement | None = None,
125+
output_layouts: Placement | None = None,
126+
use_local_output: bool = True,
127+
local_input_grad_placements: Sequence[Placement] | None = None,
128+
):
129+
super().__init__(
130+
input_layouts=input_layouts,
131+
output_layouts=output_layouts,
132+
use_local_output=use_local_output,
133+
)
134+
self.local_input_grad_placements = local_input_grad_placements
135+
136+
@staticmethod
137+
def _prepare_input_fn( # pyrefly: ignore [bad-param-name-override]
138+
input_layouts,
139+
desired_input_layouts,
140+
local_input_grad_placements,
141+
mod,
142+
inputs,
143+
device_mesh,
144+
):
145+
input_tensor = inputs[0]
146+
if not isinstance(input_tensor, DTensor):
147+
assert local_input_grad_placements is not None, (
148+
"local_input_grad_placements must be specified when input is a "
149+
"plain tensor. Please think about what you want the from_local(Replicate) backward behavior like."
150+
)
151+
input_tensor = DTensor.from_local(
152+
input_tensor,
153+
device_mesh,
154+
input_layouts,
155+
run_check=False,
156+
grad_placements=local_input_grad_placements,
157+
)
158+
159+
if input_layouts != desired_input_layouts:
160+
input_tensor = input_tensor.redistribute(
161+
placements=desired_input_layouts, async_op=True
162+
)
163+
return input_tensor
164+
165+
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
166+
if isinstance(module, nn.Linear):
167+
partition_fn = self._partition_linear_fn
168+
elif isinstance(module, nn.Embedding):
169+
partition_fn = self._partition_embedding_fn
170+
else:
171+
raise NotImplementedError(
172+
"ColwiseParallelWithGradPlacement currently only supports nn.Linear and nn.Embedding!"
173+
)
174+
175+
return distribute_module(
176+
module,
177+
device_mesh,
178+
partition_fn,
179+
partial(
180+
self._prepare_input_fn, # pyrefly: ignore [bad-argument-type]
181+
self.input_layouts,
182+
self.desired_input_layouts,
183+
self.local_input_grad_placements,
184+
),
185+
partial(
186+
self._prepare_output_fn, self.output_layouts, self.use_local_output
187+
),
188+
)
189+
190+
16191
def maybe_enable_async_tp(
17192
parallelism: ParallelismConfig, compile_config: CompileConfig, tp_mesh: DeviceMesh
18193
):

0 commit comments

Comments
 (0)