Commit d13025d

[BE] Move NoParallel to torchtitan.distributed
NoParallel should not belong to `expert_parallel`. This PR moves it to `torchtitan.distributed.__init__.py`.
1 parent: 78d4314
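
For call sites, the practical effect of this change is an import-path switch. The before/after summary below is distilled from the diffs further down; it is not itself part of the patch.

# Before this commit
from torchtitan.distributed.expert_parallel import NoParallel

# After this commit
from torchtitan.distributed import NoParallel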

5 files changed: +62 −61 lines

torchtitan/distributed/__init__.py

Lines changed: 58 additions & 1 deletion
@@ -5,7 +5,64 @@
 # LICENSE file in the root directory of this source tree.


+from torch.distributed.tensor import DeviceMesh, distribute_module, DTensor, Replicate
+from torch.distributed.tensor.parallel import ParallelStyle
+from torch.distributed.tensor.placement_types import Placement
+
 from torchtitan.distributed.parallel_dims import ParallelDims


-__all__ = ["ParallelDims"]
+__all__ = ["ParallelDims", "NoParallel"]
+
+
+# NOTE: This is to achieve replicate computation on the gate module in the MoE router.
+# It does nothing other than (1) setting the module parameters as DTensors on the given mesh
+# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back.
+# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
+# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
+class NoParallel(ParallelStyle):
+    def __init__(
+        self,
+        *,
+        input_layout: Placement | None = None,
+        output_layout: Placement | None = None,
+        use_local_output: bool = True,
+    ):
+        super().__init__()
+        self.input_layout = input_layout or Replicate()
+        self.output_layout = output_layout or Replicate()
+        self.desired_input_layout = Replicate()
+        self.use_local_output = use_local_output
+
+    @staticmethod
+    def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
+        # annotate module input placements/sharding with input_layouts
+        input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(
+                input_tensor, device_mesh, (input_layout,), run_check=False
+            )
+
+        if input_layout != desired_input_layout:
+            input_tensor = input_tensor.redistribute(
+                placements=(desired_input_layout,), async_op=True
+            )
+        return (input_tensor, *inputs[1:])
+
+    @staticmethod
+    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
+        if outputs.placements != (output_layout,):
+            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
+        # back to local tensor
+        return outputs.to_local() if use_local_output else outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            None,
+            partial(
+                self._prepare_input_fn, self.input_layout, self.desired_input_layout
+            ),
+            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
+        )
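
The sketch below (not part of this commit) shows how NoParallel might be applied to an MoE router gate so that its parameters become replicated DTensors on the same tensor-parallel mesh as the rest of the model. The Router module, its dimensions, and the "gate" FQN are made-up assumptions for illustration; it is meant to be run under torchrun so the process-group environment variables are set.

# Hypothetical usage sketch of NoParallel via the new import path.
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from torchtitan.distributed import NoParallel  # new import path after this commit


class Router(nn.Module):
    """Toy stand-in for an MoE router with a gate projection (illustrative only)."""

    def __init__(self, dim: int, num_experts: int):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gate(x)


# Assumes launch via torchrun, which sets WORLD_SIZE and friends.
world_size = int(os.environ["WORLD_SIZE"])
tp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))

router = Router(dim=256, num_experts=8).cuda()

# NoParallel places gate.weight as a Replicate DTensor on tp_mesh and converts
# the gate's inputs/outputs between torch.Tensor and DTensor at the module
# boundary, keeping all parameters on the same mesh as the TP-sharded ones.
parallelize_module(router, tp_mesh, {"gate": NoParallel()})

x = torch.randn(4, 256, device="cuda")
out = router(x)  # a plain torch.Tensor, since use_local_output=True by default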

torchtitan/distributed/expert_parallel.py

Lines changed: 0 additions & 56 deletions
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.


-from functools import partial
 from typing import Callable, Literal

 import torch
@@ -17,11 +16,9 @@
     distribute_module,
     distribute_tensor,
     DTensor,
-    Replicate,
     Shard,
 )
 from torch.distributed.tensor.parallel import ParallelStyle
-from torch.distributed.tensor.placement_types import Placement


 TOKEN_GROUP_ALIGN_SIZE_M = 8
@@ -79,59 +76,6 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         )


-# NOTE: This is to achieve replicate computation on the gate module in the MoE router.
-# It does nothing other than (1) setting the module parameters as DTensors on the given mesh
-# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back.
-# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
-# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
-class NoParallel(ParallelStyle):
-    def __init__(
-        self,
-        *,
-        input_layout: Placement | None = None,
-        output_layout: Placement | None = None,
-        use_local_output: bool = True,
-    ):
-        super().__init__()
-        self.input_layout = input_layout or Replicate()
-        self.output_layout = output_layout or Replicate()
-        self.desired_input_layout = Replicate()
-        self.use_local_output = use_local_output
-
-    @staticmethod
-    def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
-        # annotate module input placements/sharding with input_layouts
-        input_tensor = inputs[0]
-        if not isinstance(input_tensor, DTensor):
-            input_tensor = DTensor.from_local(
-                input_tensor, device_mesh, (input_layout,), run_check=False
-            )
-
-        if input_layout != desired_input_layout:
-            input_tensor = input_tensor.redistribute(
-                placements=(desired_input_layout,), async_op=True
-            )
-        return (input_tensor, *inputs[1:])
-
-    @staticmethod
-    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
-        if outputs.placements != (output_layout,):
-            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
-        # back to local tensor
-        return outputs.to_local() if use_local_output else outputs
-
-    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
-        return distribute_module(
-            module,
-            device_mesh,
-            None,
-            partial(
-                self._prepare_input_fn, self.input_layout, self.desired_input_layout
-            ),
-            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
-        )
-
-
 class ExpertParallel(ParallelStyle):
     def __init__(self):
         super().__init__()

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 2 additions & 2 deletions
@@ -19,12 +19,12 @@
     SequenceParallel,
 )
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.distributed import ParallelDims
+from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
+
 from torchtitan.distributed.expert_parallel import (
     ExpertParallel,
     ExpertTensorParallel,
-    NoParallel,
     ReordererSequenceParallel,
     TensorParallel,
 )

torchtitan/experiments/qwen3/infra/parallelize.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 )

 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.distributed import ParallelDims
+from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.expert_parallel import NoParallel
 from torchtitan.models.llama3.infra.parallelize import (

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 )

 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.distributed import ParallelDims
+from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.expert_parallel import NoParallel
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
