diff --git a/torchtitan/distributed/__init__.py b/torchtitan/distributed/__init__.py
index ff8451ec3..131262eaf 100644
--- a/torchtitan/distributed/__init__.py
+++ b/torchtitan/distributed/__init__.py
@@ -5,7 +5,67 @@
 # LICENSE file in the root directory of this source tree.
 
 
+from functools import partial
+
+import torch.nn as nn
+from torch.distributed.tensor import DeviceMesh, distribute_module, DTensor, Replicate
+from torch.distributed.tensor.parallel import ParallelStyle
+from torch.distributed.tensor.placement_types import Placement
+
 from torchtitan.distributed.parallel_dims import ParallelDims
 
-__all__ = ["ParallelDims"]
+__all__ = ["ParallelDims", "NoParallel"]
+
+
+# NOTE: This is to achieve replicate computation on the gate module in the MoE router.
+# It does nothing other than (1) setting the module parameters as DTensors on the given mesh
+# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back.
+# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
+# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
+class NoParallel(ParallelStyle):
+    def __init__(
+        self,
+        *,
+        input_layout: Placement | None = None,
+        output_layout: Placement | None = None,
+        use_local_output: bool = True,
+    ):
+        super().__init__()
+        self.input_layout = input_layout or Replicate()
+        self.output_layout = output_layout or Replicate()
+        self.desired_input_layout = Replicate()
+        self.use_local_output = use_local_output
+
+    @staticmethod
+    def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
+        # annotate module input placements/sharding with input_layouts
+        input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(
+                input_tensor, device_mesh, (input_layout,), run_check=False
+            )
+
+        if input_layout != desired_input_layout:
+            input_tensor = input_tensor.redistribute(
+                placements=(desired_input_layout,), async_op=True
+            )
+        return (input_tensor, *inputs[1:])
+
+    @staticmethod
+    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
+        if outputs.placements != (output_layout,):
+            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
+        # back to local tensor
+        return outputs.to_local() if use_local_output else outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            None,
+            partial(
+                self._prepare_input_fn, self.input_layout, self.desired_input_layout
+            ),
+            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
+        )
 
diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py
index 384d9e33f..0f9762b89 100644
--- a/torchtitan/distributed/expert_parallel.py
+++ b/torchtitan/distributed/expert_parallel.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 
-from functools import partial
 from typing import Callable, Literal
 
 import torch
@@ -16,11 +15,9 @@
     distribute_module,
     distribute_tensor,
     DTensor,
-    Replicate,
     Shard,
 )
 from torch.distributed.tensor.parallel import ParallelStyle
-from torch.distributed.tensor.placement_types import Placement
 
 # from torch.distributed._functional_collectives import all_to_all_single_autograd
 
@@ -108,59 +105,6 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         )
 
 
-# NOTE: This is to achieve replicate computation on the gate module in the MoE router.
-# It does nothing other than (1) setting the module parameters as DTensors on the given mesh
-# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back.
-# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
-# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
-class NoParallel(ParallelStyle):
-    def __init__(
-        self,
-        *,
-        input_layout: Placement | None = None,
-        output_layout: Placement | None = None,
-        use_local_output: bool = True,
-    ):
-        super().__init__()
-        self.input_layout = input_layout or Replicate()
-        self.output_layout = output_layout or Replicate()
-        self.desired_input_layout = Replicate()
-        self.use_local_output = use_local_output
-
-    @staticmethod
-    def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
-        # annotate module input placements/sharding with input_layouts
-        input_tensor = inputs[0]
-        if not isinstance(input_tensor, DTensor):
-            input_tensor = DTensor.from_local(
-                input_tensor, device_mesh, (input_layout,), run_check=False
-            )
-
-        if input_layout != desired_input_layout:
-            input_tensor = input_tensor.redistribute(
-                placements=(desired_input_layout,), async_op=True
-            )
-        return (input_tensor, *inputs[1:])
-
-    @staticmethod
-    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
-        if outputs.placements != (output_layout,):
-            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
-        # back to local tensor
-        return outputs.to_local() if use_local_output else outputs
-
-    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
-        return distribute_module(
-            module,
-            device_mesh,
-            None,
-            partial(
-                self._prepare_input_fn, self.input_layout, self.desired_input_layout
-            ),
-            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
-        )
-
-
 class ExpertParallel(ParallelStyle):
     def __init__(self):
         super().__init__()
diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py
index e51168657..4372ba7eb 100644
--- a/torchtitan/experiments/llama4/infra/parallelize.py
+++ b/torchtitan/experiments/llama4/infra/parallelize.py
@@ -19,12 +19,11 @@
     SequenceParallel,
 )
 
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.distributed import ParallelDims
+from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.distributed.expert_parallel import (
     ExpertParallel,
     ExpertTensorParallel,
-    NoParallel,
     ReordererSequenceParallel,
     TensorParallel,
 )
diff --git a/torchtitan/experiments/qwen3/infra/parallelize.py b/torchtitan/experiments/qwen3/infra/parallelize.py
index 8f7cb06ef..e20bac88d 100644
--- a/torchtitan/experiments/qwen3/infra/parallelize.py
+++ b/torchtitan/experiments/qwen3/infra/parallelize.py
@@ -21,8 +21,7 @@
 )
 
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.distributed import ParallelDims
-from torchtitan.distributed.expert_parallel import NoParallel
+from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.models.llama3.infra.parallelize import (
     apply_ac,
     apply_compile,
diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index 8423c2a8e..c74c802a8 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -17,8 +17,7 @@
 )
 
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.distributed import ParallelDims
-from torchtitan.distributed.expert_parallel import NoParallel
+from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.experiments.llama4.infra.parallelize import (
     apply_compile,
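
Note (not part of the diff): below is a minimal, hypothetical usage sketch of `NoParallel` at its new import path `torchtitan.distributed`. It assumes a single-process CPU/gloo setup and a toy `nn.Linear` standing in for the MoE router gate; the actual call sites are the `parallelize.py` files updated above.

```python
# Hypothetical sketch, not from this PR: applying NoParallel to a toy gate module
# on a 1-rank CPU mesh to show the Tensor -> DTensor -> Tensor boundary behavior.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.parallel import parallelize_module

from torchtitan.distributed import NoParallel  # new import path after this diff

# Single-process bootstrap so the example runs standalone.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))   # stand-in for the real TP/EP mesh
gate = nn.Linear(16, 4, bias=False)    # toy router gate

# Computation stays replicated, but the gate's parameters become DTensors on `mesh`,
# so they live on the same mesh as the TP-sharded modules (as assumed by gradient-norm
# clipping and the fused optimizer path mentioned in the NOTE comment).
parallelize_module(gate, mesh, NoParallel())

assert isinstance(gate.weight, DTensor)
out = gate(torch.randn(2, 16))         # plain torch.Tensor in, plain torch.Tensor out
print(type(out), out.shape)

dist.destroy_process_group()
```

The assert is the point of the wrapper: no sharding is introduced, only the parameter/DTensor bookkeeping and the module-boundary hooks, which is why moving the class out of `expert_parallel.py` into `torchtitan/distributed/__init__.py` makes it reusable outside expert parallelism.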