From 565a5f62935409025893417113c5c8d54d5b5a5e Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 28 Aug 2025 21:37:38 -0700 Subject: [PATCH 1/8] Use new DeviceMesh unflatten to rewrite parallel_dims This is a demonstration of how parallel_dims will be when using https://github.com/pytorch/pytorch/pull/161224 stack. --- torchtitan/distributed/parallel_dims.py | 141 +++++++++++++++++++++++- torchtitan/train.py | 8 +- 2 files changed, 143 insertions(+), 6 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index dbb443c6b..3144dfdbe 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from collections import defaultdict from dataclasses import dataclass from torch.distributed.device_mesh import DeviceMesh, init_device_mesh @@ -25,6 +26,7 @@ class ParallelDims: ep: int etp: int world_size: int + mesh_dim_names: tuple[str] = tuple() _world_mesh: DeviceMesh = None @@ -63,6 +65,134 @@ def _validate(self): # EP would borrow all cp and tp and some dp_shard degree assert ep % (cp * tp) == 0 and (dp_shard * cp * tp) % ep == 0 + def build_mesh(self) -> "ParallelDims": + """Build the device mesh with the required mesh dimensions. + + The following mesh dimensions may be created based on the parallel configuration: + + pp: For PP. + dp_replicate: For DDP or HSDP replicate dimension. + dp_shard_cp: For FSDP or HSDP shard dimension. This includes + ``cp`` even if ``cp`` is 1, so we just use the name + ``dp_shard_cp``. As a result, we always use the name + ``dp_shard_cp`` and ``dp_shard`` is not created as a + dimension. + dp_cp: This is used by loss all-reduce. It includes ``dp_replicate``, + ``dp_shard``, and ``cp`` as all of them are data parallelisms. + dp: This is used by data loading. It includes both ``dp_replicate`` + and ``dp_shard``. + The naming can be confusing; ``batch`` could be a better name. + cp: For CP. + tp: For TP. + ep: For EP. + dp_shard_mod_ep: For FSDP or HSDP shard dimension in EP region. + + Note: These dimensions won't exist at the same time. The meshes we need to + unflatten from world_mesh, assuming all degrees are > 1 except for ``pp``: + + ["dp", "cp", "tp"]: ``dp`` process group is wasted as dataloader + doesn't need it. + + ["dp_cp", "tp"]: loss computation + + ["dp_replicate", "dp_shard_cp", "tp"]: Non-EP region computation. + + ["dp_replicate", "dp_shard_mod_ep", "ep", "tp"]: EP region computation if etp == tp. + + ["dp_replicate", "dp_shard_mod_ep", "ep"]: EP region computation if etp == 1. 
+ """ + + def add_dim(name, degree, config): + config["name"].append(name) + config["degree"].append(degree) + + world_mesh = init_device_mesh(device_type, [self.world_size]) + dp_shard_mod_ep = ( + self.dp_shard * self.cp // self.ep + if self.etp == self.tp + else self.dp_shard * self.cp * self.tp // self.ep + ) + + data_mesh_dims = defaultdict(list) + non_ep_computation_dims = defaultdict(list) + ep_computation_dims = defaultdict(list) + + if self.pp_enabled: + add_dim("pp", self.pp, data_mesh_dims) + add_dim("pp", self.pp, non_ep_computation_dims) + add_dim("pp", self.pp, ep_computation_dims) + + if self.dp_enabled: + add_dim("dp", self.dp_replicate * self.dp_shard, data_mesh_dims) + if self.dp_replicate_enabled: + add_dim("dp_replicate", self.dp_replicate, non_ep_computation_dims) + add_dim("dp_replicate", self.dp_replicate, ep_computation_dims) + if self.dp_shard_enabled: + add_dim("dp_shard_cp", self.dp_shard * self.cp, non_ep_computation_dims) + add_dim("dp_shard_mod_ep", dp_shard_mod_ep, ep_computation_dims) + + if self.cp_enabled: + add_dim("cp", self.cp, data_mesh_dims) + + if self.tp_enabled: + add_dim("tp", self.tp, data_mesh_dims, non_ep_computation_dims) + if self.etp == self.tp: + add_dim("tp", self.tp, ep_computation_dims) + + self._all_meshes = [] + + if self.dp_enabled: + data_mesh = world_mesh._unflatten( + 0, data_mesh_dims["degree"], data_mesh_dims["name"] + ) + self._all_meshes.append(data_mesh) + # Note that we don't create loss_mesh as it is easier to flatten + # from data_mesh + if self.cp_enabled: + self._all_meshes[-1]["dp", "cp"]._flatten(mesh_dim_name="dp_cp") + else: + self._all_meshes[-1]["dp"]._flatten(mesh_dim_name="dp_cp") + + if self.dp_cp_enabled or self.tp_enabled or self.pp_enabled: + self._all_meshes.append( + world_mesh._unflatten( + 0, + non_ep_computation_dims["degree"], + non_ep_computation_dims["name"], + ) + ) + + if self.ep_enabled: + add_dim("ep", self.ep, ep_computation_dims) + self._all_meshes.append( + world_mesh._unflatten( + 0, ep_computation_dims["degree"], ep_computation_dims["name"] + ) + ) + + self._world_mesh = world_mesh + self.mesh_dim_names = tuple( + name for m in self._all_meshes for name in m.mesh_dim_names + ) + return self + + def __getitem__(self, name): + # This is a hack to make ParallelDims behave like a DeviceMesh. + # We will need to change trainer if design is concluded. For now, + # this is just a quick hack to make it work with unflatten() + + if "mesh_dim_names" == name: + return [name for m in self._all_meshes for name in m.mesh_dim_names] + + for mesh in self._all_meshes: + try: + submesh = mesh[name] + return submesh + except KeyError: + pass + raise AttributeError(f"ParallelDims has no attribute {name}") + + """ def build_mesh(self) -> DeviceMesh: # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel # is not very clean, due to the limited support from DeviceMesh @@ -188,14 +318,19 @@ def _build_mesh_without_ep(self) -> DeviceMesh: mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") return mesh + """ @property - def world_mesh(self) -> str: + def world_mesh(self) -> "ParallelDims": + # This is a hack to make ParallelDims behave like a DeviceMesh. + # We will need to change trainer if design is concluded. 
For now, + # this is just a quick hack to make it work with unflatten() + # doing late init so ParallelDims can still be used as a lightweight # dataclass without having to initialize the world mesh if self._world_mesh is None: - self._world_mesh = self.build_mesh() - return self._world_mesh + self.build_mesh() + return self @property def dp_enabled(self): diff --git a/torchtitan/train.py b/torchtitan/train.py index 9b69fd679..52489eabc 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,9 +11,9 @@ from typing import Any, Generator, Iterable, Optional import torch -from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.protocols.train_spec as train_spec_module +from torch.distributed.elastic.multiprocessing.errors import record from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderExhaustedError from torchtitan.components.ft import FTManager, maybe_semi_sync_training @@ -123,12 +123,14 @@ def __init__(self, job_config: JobConfig): # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). + """ dist_utils.set_determinism( - world_mesh, + world_mesh._world_mesh, self.device, job_config.training.seed, job_config.training.deterministic, ) + """ self.train_spec = train_spec_module.get_train_spec(job_config.model.name) # build tokenizer and dataloader @@ -611,7 +613,7 @@ def train(self): timeout=timedelta( seconds=job_config.comm.train_timeout_seconds ), - world_mesh=self.parallel_dims.world_mesh, + world_mesh=self.parallel_dims._world_mesh, ) if torch.distributed.get_rank() == 0: From 5fa85d96ffb4cbf9b54d3b7871a0870034a687ca Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 28 Aug 2025 23:00:23 -0700 Subject: [PATCH 2/8] misc --- torchtitan/distributed/parallel_dims.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 3144dfdbe..c06500e26 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -87,19 +87,24 @@ def build_mesh(self) -> "ParallelDims": ep: For EP. dp_shard_mod_ep: For FSDP or HSDP shard dimension in EP region. - Note: These dimensions won't exist at the same time. The meshes we need to - unflatten from world_mesh, assuming all degrees are > 1 except for ``pp``: + Note: These dimensions won't exist at the same time. If we consider + unflatten() operator only, following are all the meshes required + assuming all degrees are > 1 except for ``pp``: ["dp", "cp", "tp"]: ``dp`` process group is wasted as dataloader doesn't need it. - ["dp_cp", "tp"]: loss computation - ["dp_replicate", "dp_shard_cp", "tp"]: Non-EP region computation. - ["dp_replicate", "dp_shard_mod_ep", "ep", "tp"]: EP region computation if etp == tp. - ["dp_replicate", "dp_shard_mod_ep", "ep"]: EP region computation if etp == 1. + + In reality, we don't actually need to create all of these meshes. + For example, ``dp_cp`` can be sliced and flattened from ["dp", "cp", "tp"]. + So we don't actually need to create ["dp_cp", "tp"]. + + But there are some meses we MUST create if that mesh will be used for a + parameter. So Non-EP-region-computation mesh and EP-region-computation mesh + are required. 
""" def add_dim(name, degree, config): From 234f80e533b270e71d49e31ede8341eefe342a40 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 28 Aug 2025 23:01:07 -0700 Subject: [PATCH 3/8] Delete legacy code --- torchtitan/distributed/parallel_dims.py | 128 ------------------------ 1 file changed, 128 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index c06500e26..6b0a3de79 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -197,134 +197,6 @@ def __getitem__(self, name): pass raise AttributeError(f"ParallelDims has no attribute {name}") - """ - def build_mesh(self) -> DeviceMesh: - # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel - # is not very clean, due to the limited support from DeviceMesh - # for creating two staggered meshes. Will improve. - if self.ep > 1: - return self._build_mesh_with_ep() - else: - return self._build_mesh_without_ep() - - def _build_mesh_with_ep(self) -> DeviceMesh: - # With ep, dp_shard and ep are derived submeshes: - # dp_shard = dp_shard_mod_ep * dp_shard_in_ep - if self.etp == self.tp: - # ep = dp_shard_in_ep * cp - dp_shard_mod_ep = self.dp_shard * self.cp // self.ep - dp_shard_in_ep = self.ep // self.cp - else: - assert self.etp == 1 - # ep = dp_shard_in_ep * cp * tp - dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep - dp_shard_in_ep = self.ep // (self.cp * self.tp) - - dims = [] - names = [] - for d, name in zip( - [ - self.pp, - self.dp_replicate, - dp_shard_mod_ep, - dp_shard_in_ep, - self.cp, - self.tp, - ], - ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], - ): - # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping - # helps the MoE layers do mixed precision training - if d > 1 or name == "dp_shard_mod_ep": - dims.append(d) - names.append(name) - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) - - # Create all the submesh here to ensure all required process groups are - # initialized: - # Mesh for data loading (no communication on this mesh) - dp_mesh_dim_names = [] - # Mesh for param sharding - dp_shard_cp_mesh_dim_names = [] - # Mesh for loss all-reduce - dp_cp_mesh_dim_names = [] - # Mesh for ep - ep_mesh_dim_names = [] - - if self.dp_replicate_enabled: - dp_mesh_dim_names.append("dp_replicate") - dp_cp_mesh_dim_names.append("dp_replicate") - # dp_shard_mod_ep is always needed, even if it's 1 - dp_mesh_dim_names.append("dp_shard_mod_ep") - dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep") - dp_cp_mesh_dim_names.append("dp_shard_mod_ep") - if "dp_shard_in_ep" in names: - dp_mesh_dim_names.append("dp_shard_in_ep") - dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep") - dp_cp_mesh_dim_names.append("dp_shard_in_ep") - ep_mesh_dim_names.append("dp_shard_in_ep") - if self.cp_enabled: - dp_shard_cp_mesh_dim_names.append("cp") - dp_cp_mesh_dim_names.append("cp") - ep_mesh_dim_names.append("cp") - if self.etp == 1 and self.tp_enabled: - ep_mesh_dim_names.append("tp") - - mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") - mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp") - mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") - mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep") - - return mesh - - def _build_mesh_without_ep(self) -> DeviceMesh: - dims = [] - names = [] - for d, name in zip( - [self.pp, 
self.dp_replicate, self.dp_shard, self.cp, self.tp], - ["pp", "dp_replicate", "dp_shard", "cp", "tp"], - ): - if d > 1: - dims.append(d) - names.append(name) - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) - - # Create all the submesh here to ensure all required process groups are - # initialized: - # Mesh for data loading (no communication on this mesh) - dp_mesh_dim_names = [] - # Mesh for param sharding - dp_shard_cp_mesh_dim_names = [] - # Mesh for loss all-reduce - dp_cp_mesh_dim_names = [] - - if self.dp_replicate_enabled: - dp_mesh_dim_names.append("dp_replicate") - dp_cp_mesh_dim_names.append("dp_replicate") - if self.dp_shard_enabled: - dp_mesh_dim_names.append("dp_shard") - dp_shard_cp_mesh_dim_names.append("dp_shard") - dp_cp_mesh_dim_names.append("dp_shard") - if self.cp_enabled: - dp_shard_cp_mesh_dim_names.append("cp") - dp_cp_mesh_dim_names.append("cp") - - if dp_mesh_dim_names != []: - mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") - if dp_shard_cp_mesh_dim_names != []: - mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten( - mesh_dim_name="dp_shard_cp" - ) - if dp_cp_mesh_dim_names != []: - mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") - - return mesh - """ - @property def world_mesh(self) -> "ParallelDims": # This is a hack to make ParallelDims behave like a DeviceMesh. From baaa3ea3fabb7975626a1f924d0e73e12554e104 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 28 Aug 2025 23:02:55 -0700 Subject: [PATCH 4/8] misc --- torchtitan/distributed/parallel_dims.py | 128 ++++++++++++++++++++++++ 1 file changed, 128 insertions(+) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 6b0a3de79..c06500e26 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -197,6 +197,134 @@ def __getitem__(self, name): pass raise AttributeError(f"ParallelDims has no attribute {name}") + """ + def build_mesh(self) -> DeviceMesh: + # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel + # is not very clean, due to the limited support from DeviceMesh + # for creating two staggered meshes. Will improve. 
+ if self.ep > 1: + return self._build_mesh_with_ep() + else: + return self._build_mesh_without_ep() + + def _build_mesh_with_ep(self) -> DeviceMesh: + # With ep, dp_shard and ep are derived submeshes: + # dp_shard = dp_shard_mod_ep * dp_shard_in_ep + if self.etp == self.tp: + # ep = dp_shard_in_ep * cp + dp_shard_mod_ep = self.dp_shard * self.cp // self.ep + dp_shard_in_ep = self.ep // self.cp + else: + assert self.etp == 1 + # ep = dp_shard_in_ep * cp * tp + dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep + dp_shard_in_ep = self.ep // (self.cp * self.tp) + + dims = [] + names = [] + for d, name in zip( + [ + self.pp, + self.dp_replicate, + dp_shard_mod_ep, + dp_shard_in_ep, + self.cp, + self.tp, + ], + ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], + ): + # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping + # helps the MoE layers do mixed precision training + if d > 1 or name == "dp_shard_mod_ep": + dims.append(d) + names.append(name) + + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + + # Create all the submesh here to ensure all required process groups are + # initialized: + # Mesh for data loading (no communication on this mesh) + dp_mesh_dim_names = [] + # Mesh for param sharding + dp_shard_cp_mesh_dim_names = [] + # Mesh for loss all-reduce + dp_cp_mesh_dim_names = [] + # Mesh for ep + ep_mesh_dim_names = [] + + if self.dp_replicate_enabled: + dp_mesh_dim_names.append("dp_replicate") + dp_cp_mesh_dim_names.append("dp_replicate") + # dp_shard_mod_ep is always needed, even if it's 1 + dp_mesh_dim_names.append("dp_shard_mod_ep") + dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep") + dp_cp_mesh_dim_names.append("dp_shard_mod_ep") + if "dp_shard_in_ep" in names: + dp_mesh_dim_names.append("dp_shard_in_ep") + dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep") + dp_cp_mesh_dim_names.append("dp_shard_in_ep") + ep_mesh_dim_names.append("dp_shard_in_ep") + if self.cp_enabled: + dp_shard_cp_mesh_dim_names.append("cp") + dp_cp_mesh_dim_names.append("cp") + ep_mesh_dim_names.append("cp") + if self.etp == 1 and self.tp_enabled: + ep_mesh_dim_names.append("tp") + + mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") + mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp") + mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") + mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep") + + return mesh + + def _build_mesh_without_ep(self) -> DeviceMesh: + dims = [] + names = [] + for d, name in zip( + [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], + ["pp", "dp_replicate", "dp_shard", "cp", "tp"], + ): + if d > 1: + dims.append(d) + names.append(name) + + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + + # Create all the submesh here to ensure all required process groups are + # initialized: + # Mesh for data loading (no communication on this mesh) + dp_mesh_dim_names = [] + # Mesh for param sharding + dp_shard_cp_mesh_dim_names = [] + # Mesh for loss all-reduce + dp_cp_mesh_dim_names = [] + + if self.dp_replicate_enabled: + dp_mesh_dim_names.append("dp_replicate") + dp_cp_mesh_dim_names.append("dp_replicate") + if self.dp_shard_enabled: + dp_mesh_dim_names.append("dp_shard") + dp_shard_cp_mesh_dim_names.append("dp_shard") + dp_cp_mesh_dim_names.append("dp_shard") + if self.cp_enabled: + 
dp_shard_cp_mesh_dim_names.append("cp") + dp_cp_mesh_dim_names.append("cp") + + if dp_mesh_dim_names != []: + mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") + if dp_shard_cp_mesh_dim_names != []: + mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten( + mesh_dim_name="dp_shard_cp" + ) + if dp_cp_mesh_dim_names != []: + mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") + + return mesh + """ + @property def world_mesh(self) -> "ParallelDims": # This is a hack to make ParallelDims behave like a DeviceMesh. From 3f4181e149e58ec459f80b06f1e21155829774f5 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 28 Aug 2025 23:04:11 -0700 Subject: [PATCH 5/8] misc --- torchtitan/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 52489eabc..bcc7ba858 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -125,7 +125,7 @@ def __init__(self, job_config: JobConfig): # (mainly for debugging, expect perf loss). """ dist_utils.set_determinism( - world_mesh._world_mesh, + world_mesh, self.device, job_config.training.seed, job_config.training.deterministic, From 70be316d4516526b8d4bfa11202ef762da604a36 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 28 Aug 2025 23:05:02 -0700 Subject: [PATCH 6/8] lint --- torchtitan/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index bcc7ba858..4d334b02a 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,9 +11,9 @@ from typing import Any, Generator, Iterable, Optional import torch +from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.protocols.train_spec as train_spec_module -from torch.distributed.elastic.multiprocessing.errors import record from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderExhaustedError from torchtitan.components.ft import FTManager, maybe_semi_sync_training From 37161358eb9b4eae8086f3c276d6702ced898b94 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 28 Aug 2025 23:31:59 -0700 Subject: [PATCH 7/8] misc --- torchtitan/distributed/parallel_dims.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index c06500e26..685d38445 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -85,7 +85,7 @@ def build_mesh(self) -> "ParallelDims": cp: For CP. tp: For TP. ep: For EP. - dp_shard_mod_ep: For FSDP or HSDP shard dimension in EP region. + dp_shard_in_ep: For FSDP or HSDP shard dimension in EP region. Note: These dimensions won't exist at the same time. If we consider unflatten() operator only, following are all the meshes required @@ -95,8 +95,8 @@ def build_mesh(self) -> "ParallelDims": doesn't need it. ["dp_cp", "tp"]: loss computation ["dp_replicate", "dp_shard_cp", "tp"]: Non-EP region computation. - ["dp_replicate", "dp_shard_mod_ep", "ep", "tp"]: EP region computation if etp == tp. - ["dp_replicate", "dp_shard_mod_ep", "ep"]: EP region computation if etp == 1. + ["dp_replicate", "dp_shard_in_ep", "ep", "tp"]: EP region computation if etp == tp. + ["dp_replicate", "dp_shard_in_ep", "ep"]: EP region computation if etp == 1. In reality, we don't actually need to create all of these meshes. For example, ``dp_cp`` can be sliced and flattened from ["dp", "cp", "tp"]. 
@@ -112,7 +112,7 @@ def add_dim(name, degree, config): config["degree"].append(degree) world_mesh = init_device_mesh(device_type, [self.world_size]) - dp_shard_mod_ep = ( + dp_shard_in_ep = ( self.dp_shard * self.cp // self.ep if self.etp == self.tp else self.dp_shard * self.cp * self.tp // self.ep @@ -134,7 +134,7 @@ def add_dim(name, degree, config): add_dim("dp_replicate", self.dp_replicate, ep_computation_dims) if self.dp_shard_enabled: add_dim("dp_shard_cp", self.dp_shard * self.cp, non_ep_computation_dims) - add_dim("dp_shard_mod_ep", dp_shard_mod_ep, ep_computation_dims) + add_dim("dp_shard_in_ep", dp_shard_in_ep, ep_computation_dims) if self.cp_enabled: add_dim("cp", self.cp, data_mesh_dims) From a6078b82468cd3ed7556232cfc0c55d143aceafa Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 29 Aug 2025 07:22:38 -0700 Subject: [PATCH 8/8] misc --- torchtitan/distributed/parallel_dims.py | 26 ++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 685d38445..e265099db 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -73,27 +73,27 @@ def build_mesh(self) -> "ParallelDims": pp: For PP. dp_replicate: For DDP or HSDP replicate dimension. dp_shard_cp: For FSDP or HSDP shard dimension. This includes - ``cp`` even if ``cp`` is 1, so we just use the name - ``dp_shard_cp``. As a result, we always use the name - ``dp_shard_cp`` and ``dp_shard`` is not created as a - dimension. + ``cp`` even if ``cp`` is 1. As a result, we always + use the name ``dp_shard_cp``, and ``dp_shard`` is not + created as a dimension. dp_cp: This is used by loss all-reduce. It includes ``dp_replicate``, ``dp_shard``, and ``cp`` as all of them are data parallelisms. - dp: This is used by data loading. It includes both ``dp_replicate`` - and ``dp_shard``. - The naming can be confusing; ``batch`` could be a better name. + dp: This is used by data loading to decide the global batch size and + which part of data this raunk should read. This dim includes both + ``dp_replicate`` and ``dp_shard``. + The name is confusing; ``batch`` could be a better name. cp: For CP. tp: For TP. ep: For EP. - dp_shard_in_ep: For FSDP or HSDP shard dimension in EP region. + dp_shard_in_ep: For FSDP or HSDP shard dimension in the EP region. Note: These dimensions won't exist at the same time. If we consider - unflatten() operator only, following are all the meshes required + the unflatten() operator only, the following are all the meshes required assuming all degrees are > 1 except for ``pp``: - ["dp", "cp", "tp"]: ``dp`` process group is wasted as dataloader - doesn't need it. - ["dp_cp", "tp"]: loss computation + ["dp", "cp", "tp"]: The ``dp`` process group is wasted as the dataloader + doesn't need it for communication. + ["dp_cp", "tp"]: Loss computation. ["dp_replicate", "dp_shard_cp", "tp"]: Non-EP region computation. ["dp_replicate", "dp_shard_in_ep", "ep", "tp"]: EP region computation if etp == tp. ["dp_replicate", "dp_shard_in_ep", "ep"]: EP region computation if etp == 1. @@ -102,7 +102,7 @@ def build_mesh(self) -> "ParallelDims": For example, ``dp_cp`` can be sliced and flattened from ["dp", "cp", "tp"]. So we don't actually need to create ["dp_cp", "tp"]. - But there are some meses we MUST create if that mesh will be used for a + But there are some meshes we MUST create if that mesh will be used for a parameter. 
So Non-EP-region-computation mesh and EP-region-computation mesh are required. """
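The following standalone sketch condenses the build_mesh() flow introduced in these patches for readers who want to see the unflatten-based construction outside of torchtitan. It is illustrative only and not part of the patches: it assumes a PyTorch build carrying the private DeviceMesh._unflatten(dim, sizes, names) API from the pytorch/pytorch#161224 stack mentioned in the cover letter (not in any released PyTorch), relies on the existing private _flatten() API, and uses made-up degrees on a 16-rank launch (e.g. torchrun across two 8-GPU hosts).

# Illustrative only -- not part of the patches.  Example degrees:
#   dp_replicate=2, dp_shard=2, cp=2, tp=2, ep=4, etp=tp, pp=1 (so "pp" is skipped).
from torch.distributed.device_mesh import init_device_mesh

device_type = "cuda"
world_size = 16
dp_replicate, dp_shard, cp, tp, ep, etp = 2, 2, 2, 2, 4, 2

# Everything starts from one 1-D world mesh; each logical mesh below is
# unflattened from dim 0 of it instead of calling init_device_mesh repeatedly.
world_mesh = init_device_mesh(device_type, (world_size,))

# FSDP degree left over for the EP region once EP has taken its ranks:
#   etp == tp -> dp_shard_in_ep = dp_shard * cp // ep      = 2 * 2 // 4 = 1
#   etp == 1  -> dp_shard_in_ep = dp_shard * cp * tp // ep
dp_shard_in_ep = dp_shard * cp // ep if etp == tp else dp_shard * cp * tp // ep

# 1) Dataloader mesh; the "dp_cp" loss all-reduce dim is flattened out of it
#    rather than unflattened separately from the world mesh.
data_mesh = world_mesh._unflatten(
    0, [dp_replicate * dp_shard, cp, tp], ["dp", "cp", "tp"]
)
data_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")

# 2) Mesh that non-EP parameters are sharded over (HSDP/FSDP + TP).
non_ep_mesh = world_mesh._unflatten(
    0, [dp_replicate, dp_shard * cp, tp], ["dp_replicate", "dp_shard_cp", "tp"]
)

# 3) Mesh that EP-region parameters are sharded over; etp == tp keeps the "tp"
#    dim, while with etp == 1 the trailing "tp" entry would be dropped.
ep_mesh = world_mesh._unflatten(
    0,
    [dp_replicate, dp_shard_in_ep, ep, tp],
    ["dp_replicate", "dp_shard_in_ep", "ep", "tp"],
)

Every process group here comes from unflattening the single world mesh, so degree-1 dimensions (pp in this example) are simply left out of a given unflatten call, and the overlapping non-EP and EP meshes no longer have to be carved out of one N-D mesh with staggered flattens as in the deleted legacy code.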
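The train.py changes depend on ParallelDims now impersonating a DeviceMesh: the world_mesh property lazily builds all meshes and returns the ParallelDims object itself, and __getitem__ looks the requested dim name up across every unflattened mesh. A hypothetical usage sketch of that facade, assuming a checkout with these patches applied plus the patched PyTorch, launched via torchrun on 8 GPUs with made-up degrees:

from torchtitan.distributed.parallel_dims import ParallelDims

# Example degrees for one 8-GPU host: FSDP over 4 ranks, TP over 2, EP over 4.
parallel_dims = ParallelDims(
    dp_replicate=1, dp_shard=4, cp=1, tp=2, pp=1, ep=4, etp=2, world_size=8
)

world_mesh = parallel_dims.world_mesh   # late init: builds the meshes on first access
assert world_mesh is parallel_dims      # the facade hands back the ParallelDims itself

dp_mesh = world_mesh["dp"]              # resolved from the dataloader mesh
fsdp_mesh = world_mesh["dp_shard_cp"]   # resolved from the non-EP computation mesh
ep_mesh = world_mesh["ep"]              # resolved from the EP computation mesh
loss_mesh = world_mesh["dp_cp"]         # the dim flattened out of the dataloader mesh
print(world_mesh.mesh_dim_names)        # dim names gathered across all unflattened meshes

Code that previously held a real DeviceMesh keeps working through __getitem__, while call sites that need the underlying torch DeviceMesh (e.g. the timeout handling changed above) reach for parallel_dims._world_mesh directly.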