59 changes: 58 additions & 1 deletion torchtitan/distributed/__init__.py
@@ -5,7 +5,64 @@
# LICENSE file in the root directory of this source tree.


from functools import partial

import torch.nn as nn
from torch.distributed.tensor import DeviceMesh, distribute_module, DTensor, Replicate
from torch.distributed.tensor.parallel import ParallelStyle
from torch.distributed.tensor.placement_types import Placement

from torchtitan.distributed.parallel_dims import ParallelDims


__all__ = ["ParallelDims"]
__all__ = ["ParallelDims", "NoParallel"]


# NOTE: This is to achieve replicated computation on the gate module in the MoE router.
# It does nothing other than (1) setting the module parameters as DTensors on the given mesh
# and (2) inserting hooks at the module boundaries to convert torch.Tensor to DTensor and back.
# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
# which is assumed by (1) gradient norm clipping and (2) the fused optimizer implementation.
class NoParallel(ParallelStyle):
def __init__(
self,
*,
input_layout: Placement | None = None,
output_layout: Placement | None = None,
use_local_output: bool = True,
):
super().__init__()
self.input_layout = input_layout or Replicate()
self.output_layout = output_layout or Replicate()
self.desired_input_layout = Replicate()
self.use_local_output = use_local_output

@staticmethod
def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
# annotate module input placements/sharding with input_layouts
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(
input_tensor, device_mesh, (input_layout,), run_check=False
)

if input_layout != desired_input_layout:
input_tensor = input_tensor.redistribute(
placements=(desired_input_layout,), async_op=True
)
return (input_tensor, *inputs[1:])

@staticmethod
def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
if outputs.placements != (output_layout,):
outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
# back to local tensor
return outputs.to_local() if use_local_output else outputs

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
return distribute_module(
module,
device_mesh,
None,
partial(
self._prepare_input_fn, self.input_layout, self.desired_input_layout
),
partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
)
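
For context, a minimal usage sketch of the relocated NoParallel style, assuming a hypothetical standalone gate module and a 1D tensor-parallel mesh (module names and sizes are illustrative, not taken from this PR):

# Illustrative only: apply NoParallel to a MoE router gate so its parameters
# become replicated DTensors on the same TP mesh as the rest of the model.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from torchtitan.distributed import NoParallel

# Assumed setup: 8 GPUs forming a single "tp" mesh dimension.
tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

# Hypothetical router gate: a small linear layer that scores tokens per expert.
gate = torch.nn.Linear(4096, 64, bias=False)

# NoParallel replicates the gate's parameters as DTensors on tp_mesh and converts
# inputs/outputs between torch.Tensor and DTensor at the module boundary,
# without sharding any computation.
parallelize_module(gate, tp_mesh, NoParallel(use_local_output=True))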
56 changes: 0 additions & 56 deletions torchtitan/distributed/expert_parallel.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.


from functools import partial
from typing import Callable, Literal

import torch
@@ -16,11 +15,9 @@
distribute_module,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.distributed.tensor.parallel import ParallelStyle
from torch.distributed.tensor.placement_types import Placement


# from torch.distributed._functional_collectives import all_to_all_single_autograd
@@ -108,59 +105,6 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
)


# NOTE: This is to achieve replicate computation on the gate module in the MoE router.
# It does nothing other than (1) setting the module parameters as DTensors on the given mesh
# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back.
# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
class NoParallel(ParallelStyle):
def __init__(
self,
*,
input_layout: Placement | None = None,
output_layout: Placement | None = None,
use_local_output: bool = True,
):
super().__init__()
self.input_layout = input_layout or Replicate()
self.output_layout = output_layout or Replicate()
self.desired_input_layout = Replicate()
self.use_local_output = use_local_output

@staticmethod
def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
# annotate module input placements/sharding with input_layouts
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(
input_tensor, device_mesh, (input_layout,), run_check=False
)

if input_layout != desired_input_layout:
input_tensor = input_tensor.redistribute(
placements=(desired_input_layout,), async_op=True
)
return (input_tensor, *inputs[1:])

@staticmethod
def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
if outputs.placements != (output_layout,):
outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
# back to local tensor
return outputs.to_local() if use_local_output else outputs

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
return distribute_module(
module,
device_mesh,
None,
partial(
self._prepare_input_fn, self.input_layout, self.desired_input_layout
),
partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
)


class ExpertParallel(ParallelStyle):
def __init__(self):
super().__init__()
3 changes: 1 addition & 2 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -19,12 +19,11 @@
SequenceParallel,
)
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.distributed import NoParallel, ParallelDims

from torchtitan.distributed.expert_parallel import (
ExpertParallel,
ExpertTensorParallel,
NoParallel,
ReordererSequenceParallel,
TensorParallel,
)
3 changes: 1 addition & 2 deletions torchtitan/experiments/qwen3/infra/parallelize.py
@@ -21,8 +21,7 @@
)

from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.expert_parallel import NoParallel
from torchtitan.distributed import NoParallel, ParallelDims
from torchtitan.models.llama3.infra.parallelize import (
apply_ac,
apply_compile,
3 changes: 1 addition & 2 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -17,8 +17,7 @@
)

from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.expert_parallel import NoParallel
from torchtitan.distributed import NoParallel, ParallelDims
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
from torchtitan.experiments.llama4.infra.parallelize import (
apply_compile,