Skip to content

Commit 7bc866f

Browse files
committed
pipelining tutorials
1 parent f9057e4 commit 7bc866f

File tree

1 file changed

+206
-0
lines changed

1 file changed

+206
-0
lines changed
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
torch.distributed.pipelining tutorial
2+
=================================================
3+
**Author**: `Howard Huang <https://github.com/H-Huang>`_
4+
5+
.. note::
6+
|edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/intermediate_source/pipelining_tutorial.rst>`__.
7+
8+
Prerequisites:
9+
10+
- `PyTorch Distributed Overview <../beginner/dist_overview.html>`__
11+
12+
This tutorial uses a transformer model to demonstrate implementing distributed
13+
pipeline parallelism with `torch.distributed.pipelining <https://pytorch.org/docs/main/distributed.pipelining.html>`__
14+
APIs.
15+
16+
Setup
17+
------
18+
19+
With `torch.distributed.pipelining` we will be paritioning the execution of a model and scheduling computation on micro-batches. We will be using a simplified version
20+
of a transformer decoder model. The model architecture is for educational purposes and has multiple layers as we want to demonstrate how to split the model into different
21+
chunks.
22+
23+
.. code:: python
24+
25+
import torch
26+
import torch.nn as nn
27+
from dataclasses import dataclass
28+
29+
@dataclass
30+
class ModelArgs:
31+
dim: int = 512
32+
n_layers: int = 8
33+
n_heads: int = 8
34+
vocab_size: int = 10000
35+
36+
class Transformer(nn.Module):
37+
def __init__(self, model_args: ModelArgs):
38+
super().__init__()
39+
40+
self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
41+
42+
# Using a ModuleDict lets us delete layers witout affecting names,
43+
# ensuring checkpoints will correctly save and load.
44+
self.layers = torch.nn.ModuleDict()
45+
for layer_id in range(model_args.n_layers):
46+
self.layers[str(layer_id)] = nn.TransformerDecoderLayer(model_args.dim, model_args.n_heads)
47+
48+
self.norm = nn.LayerNorm(model_args.dim)
49+
self.output = nn.Linear(model_args.dim, model_args.vocab_size)
50+
51+
def forward(self, tokens: torch.Tensor):
52+
# Handling layers being 'None' at runtime enables easy pipeline splitting
53+
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
54+
55+
for layer in self.layers.values():
56+
h = layer(h, h)
57+
58+
h = self.norm(h) if self.norm else h
59+
output = self.output(h).float() if self.output else h
60+
return output
61+
62+
We first need to import the necessary libraries in our script and initialize the distributed training process. In this case, we are defining some global variables to use
63+
later in the script.
64+
65+
.. code:: python
66+
67+
import os
68+
import torch.distributed as dist
69+
from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, ScheduleGPipe
70+
71+
global rank, device, pp_group, stage_index, num_stages
72+
def init_distributed():
73+
global rank, device, pp_group, stage_index, num_stages
74+
rank = int(os.environ["LOCAL_RANK"])
75+
world_size = int(os.environ["WORLD_SIZE"])
76+
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
77+
dist.init_process_group()
78+
79+
pp_group = dist.new_group()
80+
stage_index = rank
81+
num_stages = world_size
82+
83+
The `rank`, `world_size`, and `init_process_group()` code should seem familiar to you as those are common amongst
84+
all distributed programs. The pipeline parallelism specific globals include the `pp_group` which is the process
85+
group that will be used for send/recv communications, the `stage_index` which, in this example, there is a single rank
86+
per stage so the index is equivalent to the rank, and the `num_stages` which is equivalent to world_size for the same reason.
87+
88+
89+
Step 1: Partition the Transformer Model
90+
---------------------------------------
91+
92+
There are two different ways of paritioning the model:
93+
94+
First is the manual mode in which we can manually create two instances of the model by deleting portions of
95+
attributes of the model. In this example for a 2 stage (2 rank) the model is cut in half.
96+
97+
.. code:: python
98+
99+
def manual_model_split(model, example_input_microbatch, model_args) -> PipelineStage:
100+
if stage_index == 0:
101+
# prepare the first stage model
102+
for i in range(4, 8):
103+
del model.layers[str(i)]
104+
model.norm = None
105+
model.output = None
106+
stage_input_microbatch = example_input_microbatch
107+
108+
elif stage_index == 1:
109+
# prepare the second stage model
110+
for i in range(4):
111+
del model.layers[str(i)]
112+
model.tok_embeddings = None
113+
stage_input_microbatch = torch.randn(example_input_microbatch.shape[0], model_args.dim)
114+
115+
stage = PipelineStage(
116+
model,
117+
stage_index,
118+
num_stages,
119+
device,
120+
input_args=stage_input_microbatch,
121+
)
122+
return stage
123+
124+
As we can see the first stage does not have the layer norm or the output layer, and also only includes the first 4 transformer blocks.
125+
The second stage does not have the input embedding layers, but has the output layers and the final 4 transformer blocks. The function
126+
then returns the `PipelineStage` respective of the current rank.
127+
128+
The second method is the tracer based mode which automatically splits the model based on a splitting specification, using
129+
pipeline specification we can tell `torch.distributed.pipelining` at what point the model should be split. The respective
130+
sections of the model can then be retrieved by passing in the stage_index.
131+
132+
.. code:: python
133+
def tracer_model_split(model, example_input_microbatch) -> PipelineStage:
134+
pipe = pipeline(
135+
module=model,
136+
mb_args=(example_input_microbatch,),
137+
split_spec={
138+
"layers.4": SplitPoint.BEGINNING,
139+
}
140+
)
141+
stage = pipe.build_stage(stage_index, device, pp_group)
142+
return stage
143+
144+
145+
Step 2: Define The Main Execution
146+
--------------------------------
147+
148+
In the main function we will create a particular pipeline schedule that the stages should adhere to. `torch.distributed.pipelining`
149+
supports multiple schedules including single stage per rank schedules GPipe and 1F1B and multiple stage per rank schedules such as
150+
Interleaved1F1B and LoopedBFS.
151+
152+
.. code:: python
153+
154+
if __name__ == "__main__":
155+
init_distributed()
156+
num_microbatches = 4
157+
model_args = ModelArgs()
158+
model = Transformer(model_args)
159+
160+
# Dummy data
161+
x = torch.ones(32, 500, dtype=torch.long)
162+
y = torch.randint(0, model_args.vocab_size, (32, 500), dtype=torch.long)
163+
example_input_microbatch = x.chunk(num_microbatches)[0]
164+
165+
# Option 1: Manual model splitting
166+
stage = manual_model_split(model, example_input_microbatch, model_args)
167+
168+
# Option 2: Tracer model splitting
169+
# stage = tracer_model_split(model, example_input_microbatch)
170+
171+
x = x.to(device)
172+
y = y.to(device)
173+
174+
def tokenwise_loss_fn(outputs, targets):
175+
loss_fn = nn.CrossEntropyLoss()
176+
outputs = outputs.view(-1, model_args.vocab_size)
177+
targets = targets.view(-1)
178+
return loss_fn(outputs, targets)
179+
180+
schedule = ScheduleGPipe(stage, n_microbatches=num_microbatches, loss_fn=token_loss_fn)
181+
182+
if rank == 0:
183+
schedule.step(x)
184+
elif rank == 1:
185+
losses = []
186+
output = schedule.step(target=y, losses=losses)
187+
dist.destroy_process_group()
188+
189+
We are using the manual option of splitting the model, but the code can be uncommented to also try the
190+
tracer based model splitting function. In our schedule we need to pass in the number of microbatches and
191+
the loss function we are using to evaluate the targets.
192+
193+
The `.step()` function runs the entire minibatch and performs automatic splitting into microbatches based
194+
on the `n_microbatches` passed previously. The microbatches are then operated on according to the schedule class.
195+
In this example we are using GPipe, so it is a simple all forwards and then all backwards schedule. The output
196+
returned in rank 1 will be the same as if the model was on a single GPU and run with the entire batch. Similarly,
197+
we can pass in a `losses` container to store the corresponding losses for each microbatch.
198+
199+
Step 3: Launch the Distributed Processes
200+
----------------------------
201+
202+
Finally, we are ready to run the script. We will use `torchrun` to create a single host, 2 process job.
203+
Our script is already written in a way rank 0 will contain the information for pipeline stage 0 and rank 1
204+
will contain the information for pipeline stage 1.
205+
206+
`torchrun --standalone --nnodes 1 --nproc_per_node 2 pipelining_tutorial.py`

0 commit comments

Comments
 (0)