|
| 1 | +Introduction to Distributed Pipeline Parallelism |
| 2 | +================================================ |
| 3 | +**Authors**: `Howard Huang <https://github.com/H-Huang>`_ |
| 4 | + |
| 5 | +.. note:: |
| 6 | + |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/intermediate_source/pipelining_tutorial.rst>`__. |
| 7 | + |
| 8 | +Prerequisites |
| 9 | +------------- |
| 10 | + |
| 11 | +- `PyTorch Distributed Overview <../beginner/dist_overview.html>`__ |
| 12 | + |
| 13 | +This tutorial uses a gpt-style transformer model to demonstrate implementing distributed |
| 14 | +pipeline parallelism with `torch.distributed.pipelining <https://pytorch.org/docs/main/distributed.pipelining.html>`__ |
| 15 | +APIs. |
| 16 | + |
| 17 | +Setup |
| 18 | +----- |
| 19 | + |
| 20 | +With `torch.distributed.pipelining` we will be paritioning the execution of a model and scheduling computation on micro-batches. We will be using a simplified version |
| 21 | +of a transformer decoder model. The model architecture is for educational purposes and has multiple transformer decoder layers as we want to demonstrate how to split the model into different |
| 22 | +chunks. First, let us define the model: |
| 23 | + |
| 24 | +.. code:: python |
| 25 | +
|
| 26 | + import torch |
| 27 | + import torch.nn as nn |
| 28 | + from dataclasses import dataclass |
| 29 | +
|
| 30 | + @dataclass |
| 31 | + class ModelArgs: |
| 32 | + dim: int = 512 |
| 33 | + n_layers: int = 8 |
| 34 | + n_heads: int = 8 |
| 35 | + vocab_size: int = 10000 |
| 36 | +
|
| 37 | + class Transformer(nn.Module): |
| 38 | + def __init__(self, model_args: ModelArgs): |
| 39 | + super().__init__() |
| 40 | +
|
| 41 | + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) |
| 42 | +
|
| 43 | + # Using a ModuleDict lets us delete layers witout affecting names, |
| 44 | + # ensuring checkpoints will correctly save and load. |
| 45 | + self.layers = torch.nn.ModuleDict() |
| 46 | + for layer_id in range(model_args.n_layers): |
| 47 | + self.layers[str(layer_id)] = nn.TransformerDecoderLayer(model_args.dim, model_args.n_heads) |
| 48 | +
|
| 49 | + self.norm = nn.LayerNorm(model_args.dim) |
| 50 | + self.output = nn.Linear(model_args.dim, model_args.vocab_size) |
| 51 | +
|
| 52 | + def forward(self, tokens: torch.Tensor): |
| 53 | + # Handling layers being 'None' at runtime enables easy pipeline splitting |
| 54 | + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens |
| 55 | +
|
| 56 | + for layer in self.layers.values(): |
| 57 | + h = layer(h, h) |
| 58 | +
|
| 59 | + h = self.norm(h) if self.norm else h |
| 60 | + output = self.output(h).float() if self.output else h |
| 61 | + return output |
| 62 | +
|
| 63 | +Then, we need to import the necessary libraries in our script and initialize the distributed training process. In this case, we are defining some global variables to use |
| 64 | +later in the script: |
| 65 | + |
| 66 | +.. code:: python |
| 67 | +
|
| 68 | + import os |
| 69 | + import torch.distributed as dist |
| 70 | + from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, ScheduleGPipe |
| 71 | +
|
| 72 | + global rank, device, pp_group, stage_index, num_stages |
| 73 | + def init_distributed(): |
| 74 | + global rank, device, pp_group, stage_index, num_stages |
| 75 | + rank = int(os.environ["LOCAL_RANK"]) |
| 76 | + world_size = int(os.environ["WORLD_SIZE"]) |
| 77 | + device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu") |
| 78 | + dist.init_process_group() |
| 79 | +
|
| 80 | + pp_group = dist.new_group() |
| 81 | + stage_index = rank |
| 82 | + num_stages = world_size |
| 83 | +
|
| 84 | +The `rank`, `world_size`, and `init_process_group()` code should seem familiar to you as those are common amongst |
| 85 | +all distributed programs. The pipeline parallelism specific globals include the `pp_group` which is the process |
| 86 | +group that will be used for send/recv communications, the `stage_index` which, in this example, is a single rank |
| 87 | +per stage so the index is equivalent to the rank, and the `num_stages` which is equivalent to world_size. |
| 88 | + |
| 89 | +The `num_stages` is used to set the number of stages that will be used in the pipeline parallelism schedule. For example, |
| 90 | +for `num_stages=4`, a microbatch will need to go through 4 forwards and 4 backwards before it is completed. The `stage_index` |
| 91 | +is necessary for the framework to know how to communicate between stages. For example, for the first stage (`stage_index=0`), it will |
| 92 | +use data from the dataloader and does not need to receive data from any previous peers to perform its computation. |
| 93 | + |
| 94 | + |
| 95 | +Step 1: Partition the Transformer Model |
| 96 | +--------------------------------------- |
| 97 | + |
| 98 | +There are two different ways of partitioning the model: |
| 99 | + |
| 100 | +First is the manual mode in which we can manually create two instances of the model by deleting portions of |
| 101 | +attributes of the model. In this example for a 2 stage (2 ranks) the model is cut in half. |
| 102 | + |
| 103 | +.. code:: python |
| 104 | +
|
| 105 | + def manual_model_split(model, example_input_microbatch, model_args) -> PipelineStage: |
| 106 | + if stage_index == 0: |
| 107 | + # prepare the first stage model |
| 108 | + for i in range(4, 8): |
| 109 | + del model.layers[str(i)] |
| 110 | + model.norm = None |
| 111 | + model.output = None |
| 112 | + stage_input_microbatch = example_input_microbatch |
| 113 | +
|
| 114 | + elif stage_index == 1: |
| 115 | + # prepare the second stage model |
| 116 | + for i in range(4): |
| 117 | + del model.layers[str(i)] |
| 118 | + model.tok_embeddings = None |
| 119 | + stage_input_microbatch = torch.randn(example_input_microbatch.shape[0], model_args.dim) |
| 120 | +
|
| 121 | + stage = PipelineStage( |
| 122 | + model, |
| 123 | + stage_index, |
| 124 | + num_stages, |
| 125 | + device, |
| 126 | + input_args=stage_input_microbatch, |
| 127 | + ) |
| 128 | + return stage |
| 129 | +
|
| 130 | +As we can see the first stage does not have the layer norm or the output layer, and also only includes the first 4 transformer blocks. |
| 131 | +The second stage does not have the input embedding layers, but has the output layers and the final 4 transformer blocks. The function |
| 132 | +then returns the `PipelineStage` for the current rank. |
| 133 | + |
| 134 | +The second method is the tracer based mode which automatically splits the model based on a splitting specification, using |
| 135 | +pipeline specification we can tell `torch.distributed.pipelining` at what point the model should be split. In the code block, |
| 136 | +we are splitting before the before 4th transformer decoder layer, which is the same as what was done manually above. Similarly, |
| 137 | +we can retrieve a `PipelineStage` by calling `build_stage` after this splitting is done. |
| 138 | + |
| 139 | +.. code:: python |
| 140 | + def tracer_model_split(model, example_input_microbatch) -> PipelineStage: |
| 141 | + pipe = pipeline( |
| 142 | + module=model, |
| 143 | + mb_args=(example_input_microbatch,), |
| 144 | + split_spec={ |
| 145 | + "layers.4": SplitPoint.BEGINNING, |
| 146 | + } |
| 147 | + ) |
| 148 | + stage = pipe.build_stage(stage_index, device, pp_group) |
| 149 | + return stage |
| 150 | +
|
| 151 | +
|
| 152 | +Step 2: Define The Main Execution |
| 153 | +--------------------------------- |
| 154 | + |
| 155 | +In the main function we will create a particular pipeline schedule that the stages should adhere to. `torch.distributed.pipelining` |
| 156 | +supports multiple schedules including single stage per rank schedules GPipe and 1F1B and multiple stage per rank schedules such as |
| 157 | +Interleaved1F1B and LoopedBFS. |
| 158 | + |
| 159 | +.. code:: python |
| 160 | +
|
| 161 | + if __name__ == "__main__": |
| 162 | + init_distributed() |
| 163 | + num_microbatches = 4 |
| 164 | + model_args = ModelArgs() |
| 165 | + model = Transformer(model_args) |
| 166 | +
|
| 167 | + # Dummy data |
| 168 | + x = torch.ones(32, 500, dtype=torch.long) |
| 169 | + y = torch.randint(0, model_args.vocab_size, (32, 500), dtype=torch.long) |
| 170 | + example_input_microbatch = x.chunk(num_microbatches)[0] |
| 171 | +
|
| 172 | + # Option 1: Manual model splitting |
| 173 | + stage = manual_model_split(model, example_input_microbatch, model_args) |
| 174 | +
|
| 175 | + # Option 2: Tracer model splitting |
| 176 | + # stage = tracer_model_split(model, example_input_microbatch) |
| 177 | +
|
| 178 | + x = x.to(device) |
| 179 | + y = y.to(device) |
| 180 | +
|
| 181 | + def tokenwise_loss_fn(outputs, targets): |
| 182 | + loss_fn = nn.CrossEntropyLoss() |
| 183 | + outputs = outputs.view(-1, model_args.vocab_size) |
| 184 | + targets = targets.view(-1) |
| 185 | + return loss_fn(outputs, targets) |
| 186 | +
|
| 187 | + schedule = ScheduleGPipe(stage, n_microbatches=num_microbatches, loss_fn=token_loss_fn) |
| 188 | +
|
| 189 | + if rank == 0: |
| 190 | + schedule.step(x) |
| 191 | + elif rank == 1: |
| 192 | + losses = [] |
| 193 | + output = schedule.step(target=y, losses=losses) |
| 194 | + dist.destroy_process_group() |
| 195 | +
|
| 196 | +We are using the manual option of splitting the model, but the code can be uncommented to also try the |
| 197 | +tracer based model splitting function. In our schedule we need to pass in the number of microbatches and |
| 198 | +the loss function we are using to evaluate the targets. |
| 199 | + |
| 200 | +The `.step()` function runs the entire minibatch and performs automatic splitting into microbatches based |
| 201 | +on the `n_microbatches` passed previously. The microbatches are then operated on according to the schedule class. |
| 202 | +In this example we are using GPipe, so it is a simple all forwards and then all backwards schedule. The output |
| 203 | +returned from rank 1 will be the same as if the model was on a single GPU and run with the entire batch. Similarly, |
| 204 | +we can pass in a `losses` container to store the corresponding losses for each microbatch. |
| 205 | + |
| 206 | +Step 3: Launch the Distributed Processes |
| 207 | +---------------------------------------- |
| 208 | + |
| 209 | +Finally, we are ready to run the script. We will use `torchrun` to create a single host, 2 process job. |
| 210 | +Our script is already written in a way rank 0 will performs the required logic for pipeline stage 0 and rank 1 |
| 211 | +will performs the logic for pipeline stage 1. |
| 212 | + |
| 213 | +`torchrun --standalone --nnodes 1 --nproc_per_node 2 pipelining_tutorial.py` |
0 commit comments