Using CommDebugMode
=====================================================

**Author**: `Anshul Sinha <https://github.com/sinhaanshul>`__

Prerequisites:

- `Distributed Communication Package - torch.distributed <https://pytorch.org/docs/stable/distributed.html>`__
- Python 3.8 - 3.11
- PyTorch 2.2

What is CommDebugMode and why is it useful
-------------------------------------------
As the size of models continues to increase, users are looking to combine multiple parallelism strategies to scale up distributed training. However, the lack of interoperability between existing solutions poses a significant challenge, largely due to the absence of a unified abstraction that can bridge these different parallelism strategies. To address this issue, PyTorch proposed DistributedTensor (DTensor), which abstracts away the complexities of tensor communication in distributed training and provides a seamless user experience. However, this abstraction creates a lack of transparency that can make it challenging for users to identify and resolve issues. CommDebugMode, a Python context manager, addresses this problem by letting users see when and why collective operations happen when using DTensors, making it one of the primary debugging tools for DTensor.

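Below is a minimal sketch of how ``CommDebugMode`` is typically used as a context manager. The import path and the reporting methods shown here are assumptions that may vary across PyTorch releases, and ``model`` and ``inp`` are hypothetical placeholders for a DTensor-parallelized module and its input batch.

.. code-block:: python

    # Note: the import path is an assumption and may differ between releases;
    # ``torch.distributed._tensor.debug`` is used here as an example.
    from torch.distributed._tensor.debug import CommDebugMode

    # ``model`` is assumed to be a module parallelized with DTensor and
    # ``inp`` an input batch; both are placeholders for illustration.
    comm_mode = CommDebugMode()
    with comm_mode:
        output = model(inp)

    # Inspect which collective operations (for example, all_reduce or
    # all_gather) were triggered while the context manager was active.
    print(comm_mode.get_total_counts())
    print(comm_mode.get_comm_counts())
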
Why DeviceMesh is Useful
------------------------
DeviceMesh is useful when working with multi-dimensional parallelism (i.e. 3-D parallel) where parallelism composability is required, for example, when your parallelism solution requires both communication across hosts and within each host.
For such a setup, we can create a 2D mesh that connects the devices within each host and connects each device with its counterpart on the other hosts in a homogeneous setup.

Without DeviceMesh, users would need to manually set up NCCL communicators and CUDA devices on each process before applying any parallelism, which can be quite complicated.
The following code snippet illustrates a hybrid sharding 2-D parallel pattern setup without :class:`DeviceMesh`.
First, we need to manually calculate the shard group and replicate group. Then, we need to assign the correct shard and
replicate group to each rank.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist

    # Understand world topology
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    print(f"Running example on {rank=} in a world with {world_size=}")

    # Create process groups to manage 2-D like parallel pattern
    dist.init_process_group("nccl")
    torch.cuda.set_device(rank)

    # Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7))
    # and assign the correct shard group to each rank
    num_node_devices = torch.cuda.device_count()
    shard_rank_lists = list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices))
    shard_groups = (
        dist.new_group(shard_rank_lists[0]),
        dist.new_group(shard_rank_lists[1]),
    )
    current_shard_group = (
        shard_groups[0] if rank in shard_rank_lists[0] else shard_groups[1]
    )

    # Create replicate groups (for example, (0, 4), (1, 5), (2, 6), (3, 7))
    # and assign the correct replicate group to each rank
    current_replicate_group = None
    shard_factor = len(shard_rank_lists[0])
    for i in range(num_node_devices // 2):
        replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
        replicate_group = dist.new_group(replicate_group_ranks)
        if rank in replicate_group_ranks:
            current_replicate_group = replicate_group

To run the above code snippet, we can leverage PyTorch Elastic. Let's create a file named ``2d_setup.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.

.. code-block:: bash

    torchrun --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup.py

.. note::
    For simplicity of demonstration, we are simulating 2D parallel using only one node. Note that this code snippet can also be used when running on a multi-host setup.

With the help of :func:`init_device_mesh`, we can accomplish the above 2D setup in just two lines, and we can still
access the underlying :class:`ProcessGroup` if needed.


.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh
    mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

    # Users can access the underlying process group through the `get_group` API.
    replicate_group = mesh_2d.get_group(mesh_dim="replicate")
    shard_group = mesh_2d.get_group(mesh_dim="shard")

Let's create a file named ``2d_setup_with_device_mesh.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.

.. code-block:: bash

    torchrun --nproc_per_node=8 2d_setup_with_device_mesh.py


How to use DeviceMesh with HSDP
-------------------------------

Hybrid Sharding Data Parallel (HSDP) is a 2D strategy that performs FSDP within a host and DDP across hosts.

Let's see an example of how DeviceMesh can assist with applying HSDP to your model with a simple setup. With DeviceMesh,
users would not need to manually create and manage shard groups and replicate groups.

.. code-block:: python

    import torch
    import torch.nn as nn

    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(10, 10)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    # HSDP: MeshShape(2, 4)
    mesh_2d = init_device_mesh("cuda", (2, 4))
    model = FSDP(
        ToyModel(), device_mesh=mesh_2d, sharding_strategy=ShardingStrategy.HYBRID_SHARD
    )

Let's create a file named ``hsdp.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.

.. code-block:: bash

    torchrun --nproc_per_node=8 hsdp.py

How to use DeviceMesh for your custom parallel solutions
--------------------------------------------------------
When working with large-scale training, you might have more complex custom parallel training compositions. For example, you may need to slice out submeshes for different parallelism solutions.
DeviceMesh allows users to slice child meshes from the parent mesh and re-use the NCCL communicators already created when the parent mesh was initialized.

.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh
    mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("replicate", "shard", "tp"))

    # Users can slice child meshes from the parent mesh.
    hsdp_mesh = mesh_3d["replicate", "shard"]
    tp_mesh = mesh_3d["tp"]

    # Users can access the underlying process group through the `get_group` API.
    replicate_group = hsdp_mesh["replicate"].get_group()
    shard_group = hsdp_mesh["shard"].get_group()
    tp_group = tp_mesh.get_group()


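To show how a sliced submesh could plug back into the HSDP setup from the previous section, here is a minimal sketch that passes the sliced 2-D ``hsdp_mesh`` to FSDP. It assumes a sliced submesh is accepted by FSDP the same way as a directly initialized mesh and reuses the ``ToyModel`` defined earlier for illustration.

.. code-block:: python

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

    # Reuse the sliced 2-D submesh ("replicate", "shard") for HSDP, leaving the
    # "tp" submesh free for a tensor-parallel style solution on top of it.
    # ``ToyModel`` is the toy module from the HSDP section above.
    model = FSDP(
        ToyModel(), device_mesh=hsdp_mesh, sharding_strategy=ShardingStrategy.HYBRID_SHARD
    )
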
Conclusion
----------
In conclusion, we have learned about :class:`DeviceMesh` and :func:`init_device_mesh`, as well as how
they can be used to describe the layout of devices across the cluster.

For more information, please see the following:

- `2D parallel combining Tensor/Sequence Parallel with FSDP <https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/fsdp_tp_example.py>`__
- `Composable PyTorch Distributed with PT2 <https://static.sched.com/hosted_files/pytorch2023/d1/%5BPTC%2023%5D%20Composable%20PyTorch%20Distributed%20with%20PT2.pdf>`__