
Commit 2aab23c

[BE] Move NoParallel to torchtitan.distributed
NoParallel does not belong in `expert_parallel`. This PR moves it to `torchtitan.distributed.__init__.py`.
1 parent: cd337db
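For downstream callers only the import path changes; the class body is identical. A minimal before/after sketch of the import update, mirroring the parallelize.py changes below:

    # Before this commit
    from torchtitan.distributed import ParallelDims
    from torchtitan.distributed.expert_parallel import NoParallel

    # After this commit
    from torchtitan.distributed import NoParallel, ParallelDims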

File tree

5 files changed: +61 −63 lines


torchtitan/distributed/__init__.py

Lines changed: 58 additions & 1 deletion
@@ -5,7 +5,64 @@
 # LICENSE file in the root directory of this source tree.


+from torch.distributed.tensor import DeviceMesh, distribute_module, DTensor, Replicate
+from torch.distributed.tensor.parallel import ParallelStyle
+from torch.distributed.tensor.placement_types import Placement
+
 from torchtitan.distributed.parallel_dims import ParallelDims


-__all__ = ["ParallelDims"]
+__all__ = ["ParallelDims", "NoParallel"]
+
+
+# NOTE: This is to achieve replicate computation on the gate module in the MoE router.
+# It does nothing other than (1) setting the module parameters as DTensors on the given mesh
+# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back.
+# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
+# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
+class NoParallel(ParallelStyle):
+    def __init__(
+        self,
+        *,
+        input_layout: Placement | None = None,
+        output_layout: Placement | None = None,
+        use_local_output: bool = True,
+    ):
+        super().__init__()
+        self.input_layout = input_layout or Replicate()
+        self.output_layout = output_layout or Replicate()
+        self.desired_input_layout = Replicate()
+        self.use_local_output = use_local_output
+
+    @staticmethod
+    def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
+        # annotate module input placements/sharding with input_layouts
+        input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(
+                input_tensor, device_mesh, (input_layout,), run_check=False
+            )
+
+        if input_layout != desired_input_layout:
+            input_tensor = input_tensor.redistribute(
+                placements=(desired_input_layout,), async_op=True
+            )
+        return (input_tensor, *inputs[1:])
+
+    @staticmethod
+    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
+        if outputs.placements != (output_layout,):
+            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
+        # back to local tensor
+        return outputs.to_local() if use_local_output else outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            None,
+            partial(
+                self._prepare_input_fn, self.input_layout, self.desired_input_layout
+            ),
+            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
+        )
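
For context on how the relocated class is consumed: NoParallel is a ParallelStyle, so it can be passed to torch.distributed.tensor.parallel.parallelize_module like any other style. The sketch below is illustrative only and not part of this diff; the "moe.router.gate" module path and the helper name are assumptions.

    import torch.nn as nn
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.tensor.parallel import parallelize_module

    from torchtitan.distributed import NoParallel  # new import location after this commit


    def parallelize_router_gate(block: nn.Module, tp_mesh: DeviceMesh) -> nn.Module:
        # Hypothetical helper: "moe.router.gate" is an assumed module path, not from this diff.
        # NoParallel keeps the gate's computation replicated but registers its parameters as
        # DTensors on tp_mesh, so every parameter lives on the same mesh -- the property that
        # gradient-norm clipping and the fused optimizer rely on (see the NOTE above).
        return parallelize_module(
            module=block,
            device_mesh=tp_mesh,
            parallelize_plan={"moe.router.gate": NoParallel()},
        )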

torchtitan/distributed/expert_parallel.py

Lines changed: 0 additions & 56 deletions
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.


-from functools import partial
 from typing import Callable, Literal

 import torch
@@ -16,11 +15,9 @@
     distribute_module,
     distribute_tensor,
     DTensor,
-    Replicate,
     Shard,
 )
 from torch.distributed.tensor.parallel import ParallelStyle
-from torch.distributed.tensor.placement_types import Placement


 # from torch.distributed._functional_collectives import all_to_all_single_autograd
@@ -108,59 +105,6 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         )


-# NOTE: This is to achieve replicate computation on the gate module in the MoE router.
-# It does nothing other than (1) setting the module parameters as DTensors on the given mesh
-# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back.
-# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
-# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
-class NoParallel(ParallelStyle):
-    def __init__(
-        self,
-        *,
-        input_layout: Placement | None = None,
-        output_layout: Placement | None = None,
-        use_local_output: bool = True,
-    ):
-        super().__init__()
-        self.input_layout = input_layout or Replicate()
-        self.output_layout = output_layout or Replicate()
-        self.desired_input_layout = Replicate()
-        self.use_local_output = use_local_output
-
-    @staticmethod
-    def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
-        # annotate module input placements/sharding with input_layouts
-        input_tensor = inputs[0]
-        if not isinstance(input_tensor, DTensor):
-            input_tensor = DTensor.from_local(
-                input_tensor, device_mesh, (input_layout,), run_check=False
-            )
-
-        if input_layout != desired_input_layout:
-            input_tensor = input_tensor.redistribute(
-                placements=(desired_input_layout,), async_op=True
-            )
-        return (input_tensor, *inputs[1:])
-
-    @staticmethod
-    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
-        if outputs.placements != (output_layout,):
-            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
-        # back to local tensor
-        return outputs.to_local() if use_local_output else outputs
-
-    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
-        return distribute_module(
-            module,
-            device_mesh,
-            None,
-            partial(
-                self._prepare_input_fn, self.input_layout, self.desired_input_layout
-            ),
-            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
-        )
-
-
 class ExpertParallel(ParallelStyle):
     def __init__(self):
         super().__init__()

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 1 addition & 2 deletions
@@ -19,12 +19,11 @@
     SequenceParallel,
 )
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.distributed import ParallelDims
+from torchtitan.distributed import NoParallel, ParallelDims

 from torchtitan.distributed.expert_parallel import (
     ExpertParallel,
     ExpertTensorParallel,
-    NoParallel,
     ReordererSequenceParallel,
     TensorParallel,
 )

torchtitan/experiments/qwen3/infra/parallelize.py

Lines changed: 1 addition & 2 deletions
@@ -21,8 +21,7 @@
 )

 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.distributed import ParallelDims
-from torchtitan.distributed.expert_parallel import NoParallel
+from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.models.llama3.infra.parallelize import (
     apply_ac,
     apply_compile,

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 1 addition & 2 deletions
@@ -17,8 +17,7 @@
 )

 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.distributed import ParallelDims
-from torchtitan.distributed.expert_parallel import NoParallel
+from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.experiments.llama4.infra.parallelize import (
     apply_compile,

0 commit comments
