diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py
index dbb443c6b..e265099db 100644
--- a/torchtitan/distributed/parallel_dims.py
+++ b/torchtitan/distributed/parallel_dims.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from collections import defaultdict
 from dataclasses import dataclass
 
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
@@ -25,6 +26,7 @@ class ParallelDims:
     ep: int
     etp: int
     world_size: int
+    mesh_dim_names: tuple[str, ...] = tuple()
 
     _world_mesh: DeviceMesh = None
 
@@ -63,6 +65,139 @@ def _validate(self):
             # EP would borrow all cp and tp and some dp_shard degree
             assert ep % (cp * tp) == 0 and (dp_shard * cp * tp) % ep == 0
 
+    def build_mesh(self) -> "ParallelDims":
+        """Build the device mesh with the required mesh dimensions.
+
+        The following mesh dimensions may be created based on the parallel configuration:
+
+        pp: For PP.
+        dp_replicate: For DDP or HSDP replicate dimension.
+        dp_shard_cp: For FSDP or HSDP shard dimension. This includes
+                     ``cp`` even if ``cp`` is 1. As a result, we always
+                     use the name ``dp_shard_cp``, and ``dp_shard`` is not
+                     created as a dimension.
+        dp_cp: This is used by loss all-reduce. It includes ``dp_replicate``,
+               ``dp_shard``, and ``cp`` as all of them are forms of data parallelism.
+        dp: This is used by data loading to decide the global batch size and
+            which part of the data this rank should read. This dim includes both
+            ``dp_replicate`` and ``dp_shard``.
+            The name is confusing; ``batch`` could be a better name.
+        cp: For CP.
+        tp: For TP.
+        ep: For EP.
+        dp_shard_in_ep: For FSDP or HSDP shard dimension in the EP region.
+
+        Note: These dimensions won't all exist at the same time. If we consider
+        the unflatten() operator only, the following are all the meshes required,
+        assuming all degrees are > 1 except for ``pp``:
+
+        ["dp", "cp", "tp"]: The ``dp`` process group is wasted as the dataloader
+                            doesn't need it for communication.
+        ["dp_cp", "tp"]: Loss computation.
+        ["dp_replicate", "dp_shard_cp", "tp"]: Non-EP region computation.
+        ["dp_replicate", "dp_shard_in_ep", "ep", "tp"]: EP region computation if etp == tp.
+        ["dp_replicate", "dp_shard_in_ep", "ep"]: EP region computation if etp == 1.
+
+        In reality, we don't actually need to create all of these meshes.
+        For example, ``dp_cp`` can be sliced and flattened from ["dp", "cp", "tp"],
+        so we don't need to create ["dp_cp", "tp"].
+
+        But there are some meshes we MUST create if they will be used for
+        parameters. So the non-EP-region computation mesh and the EP-region
+        computation mesh are required.
+ """ + + def add_dim(name, degree, config): + config["name"].append(name) + config["degree"].append(degree) + + world_mesh = init_device_mesh(device_type, [self.world_size]) + dp_shard_in_ep = ( + self.dp_shard * self.cp // self.ep + if self.etp == self.tp + else self.dp_shard * self.cp * self.tp // self.ep + ) + + data_mesh_dims = defaultdict(list) + non_ep_computation_dims = defaultdict(list) + ep_computation_dims = defaultdict(list) + + if self.pp_enabled: + add_dim("pp", self.pp, data_mesh_dims) + add_dim("pp", self.pp, non_ep_computation_dims) + add_dim("pp", self.pp, ep_computation_dims) + + if self.dp_enabled: + add_dim("dp", self.dp_replicate * self.dp_shard, data_mesh_dims) + if self.dp_replicate_enabled: + add_dim("dp_replicate", self.dp_replicate, non_ep_computation_dims) + add_dim("dp_replicate", self.dp_replicate, ep_computation_dims) + if self.dp_shard_enabled: + add_dim("dp_shard_cp", self.dp_shard * self.cp, non_ep_computation_dims) + add_dim("dp_shard_in_ep", dp_shard_in_ep, ep_computation_dims) + + if self.cp_enabled: + add_dim("cp", self.cp, data_mesh_dims) + + if self.tp_enabled: + add_dim("tp", self.tp, data_mesh_dims, non_ep_computation_dims) + if self.etp == self.tp: + add_dim("tp", self.tp, ep_computation_dims) + + self._all_meshes = [] + + if self.dp_enabled: + data_mesh = world_mesh._unflatten( + 0, data_mesh_dims["degree"], data_mesh_dims["name"] + ) + self._all_meshes.append(data_mesh) + # Note that we don't create loss_mesh as it is easier to flatten + # from data_mesh + if self.cp_enabled: + self._all_meshes[-1]["dp", "cp"]._flatten(mesh_dim_name="dp_cp") + else: + self._all_meshes[-1]["dp"]._flatten(mesh_dim_name="dp_cp") + + if self.dp_cp_enabled or self.tp_enabled or self.pp_enabled: + self._all_meshes.append( + world_mesh._unflatten( + 0, + non_ep_computation_dims["degree"], + non_ep_computation_dims["name"], + ) + ) + + if self.ep_enabled: + add_dim("ep", self.ep, ep_computation_dims) + self._all_meshes.append( + world_mesh._unflatten( + 0, ep_computation_dims["degree"], ep_computation_dims["name"] + ) + ) + + self._world_mesh = world_mesh + self.mesh_dim_names = tuple( + name for m in self._all_meshes for name in m.mesh_dim_names + ) + return self + + def __getitem__(self, name): + # This is a hack to make ParallelDims behave like a DeviceMesh. + # We will need to change trainer if design is concluded. For now, + # this is just a quick hack to make it work with unflatten() + + if "mesh_dim_names" == name: + return [name for m in self._all_meshes for name in m.mesh_dim_names] + + for mesh in self._all_meshes: + try: + submesh = mesh[name] + return submesh + except KeyError: + pass + raise AttributeError(f"ParallelDims has no attribute {name}") + + """ def build_mesh(self) -> DeviceMesh: # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel # is not very clean, due to the limited support from DeviceMesh @@ -188,14 +323,19 @@ def _build_mesh_without_ep(self) -> DeviceMesh: mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") return mesh + """ @property - def world_mesh(self) -> str: + def world_mesh(self) -> "ParallelDims": + # This is a hack to make ParallelDims behave like a DeviceMesh. + # We will need to change trainer if design is concluded. 
For now, + # this is just a quick hack to make it work with unflatten() + # doing late init so ParallelDims can still be used as a lightweight # dataclass without having to initialize the world mesh if self._world_mesh is None: - self._world_mesh = self.build_mesh() - return self._world_mesh + self.build_mesh() + return self @property def dp_enabled(self): diff --git a/torchtitan/train.py b/torchtitan/train.py index 9b69fd679..4d334b02a 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -123,12 +123,14 @@ def __init__(self, job_config: JobConfig): # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). + """ dist_utils.set_determinism( world_mesh, self.device, job_config.training.seed, job_config.training.deterministic, ) + """ self.train_spec = train_spec_module.get_train_spec(job_config.model.name) # build tokenizer and dataloader @@ -611,7 +613,7 @@ def train(self): timeout=timedelta( seconds=job_config.comm.train_timeout_seconds ), - world_mesh=self.parallel_dims.world_mesh, + world_mesh=self.parallel_dims._world_mesh, ) if torch.distributed.get_rank() == 0: