This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 9dc9eff

Add pipeline parallel (#1060)
ghstack-source-id: 02acf73
Pull Request resolved: #1060
1 parent 7191b57 commit 9dc9eff

2 files changed: +77 -29


build/model_dist.py

Lines changed: 37 additions & 22 deletions
@@ -30,27 +30,35 @@
 device_mesh = None


-class Transformer(nn.Module):
-    def __init__(self, config: TransformerArgs) -> None:
+class TransformerStage(nn.Module):
+    def __init__(self, config: TransformerArgs, stage_idx: int, n_stages: int) -> None:
         super().__init__()
         self.config = config
+        self.stage_idx = stage_idx
+        self.n_stages = n_stages
+        self.layers_per_stage = config.n_layers // n_stages

         # Get device mesh
         global device_mesh
         if device_mesh is None:
             device_mesh = _mesh_resources.get_current_mesh()

-        tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
-        self.tok_embeddings = parallelize_module(
-            tok_embeddings,
-            device_mesh,
-            RowwiseParallel(input_layouts=Replicate()),
-        )
-        self.layers = nn.ModuleList(
-            TransformerBlock(config) for _ in range(config.n_layers)
-        )
-        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
-        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+        if stage_idx == 0:
+            tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
+            self.tok_embeddings = parallelize_module(
+                tok_embeddings,
+                device_mesh,
+                RowwiseParallel(input_layouts=Replicate()),
+            )
+
+        # Use ModuleDict so that each layer can be assigned its layer ID in the original model
+        self.layers = nn.ModuleDict()
+        for layer_id in range(self.layers_per_stage * stage_idx, self.layers_per_stage * (stage_idx + 1)):
+            self.layers[str(layer_id)] = TransformerBlock(config)
+
+        if stage_idx == n_stages - 1:
+            self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+            self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

         # self.freqs_cis: Optional[Tensor] = None
         # self.mask_cache: Optional[Tensor] = None
@@ -67,7 +75,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
-        for b in self.layers:
+        for b in self.layers.values():
             b.attention.kv_cache = KVCache(
                 max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
             )
@@ -84,19 +92,26 @@ def setup_caches(self, max_batch_size, max_seq_length):
             )
         self.register_buffer("causal_mask", causal_mask, persistent=True)

-    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
+        if input_pos is None:
+            input_pos = torch.arange(x.shape[1], device=x.device, dtype=torch.long)
         mask = self.causal_mask[None, None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
-        x: DTensor = self.tok_embeddings(idx)
-        # TODO: sequence parallelize this

-        for _, layer in enumerate(self.layers):
+        if self.stage_idx == 0:
+            x: DTensor = self.tok_embeddings(x)
+            # TODO: sequence parallelize this
+
+        for _, layer in self.layers.items():
             x = layer(x, input_pos, freqs_cis, mask)
-        x = self.norm(x)
-        logits = self.output(x)
-        # print(f"logits shape: {logits.shape}")
-        return logits
+
+        if self.stage_idx == self.n_stages - 1:
+            x = self.norm(x)
+            x = self.output(x)
+
+        # print(f"stage output shape: {x.shape}")
+        return x

     @classmethod
     def from_name(cls, name: str):
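
The change above gives each pipeline stage a contiguous slice of config.n_layers // n_stages transformer blocks, keyed in the ModuleDict by their layer IDs from the original model; only stage 0 owns the token embedding and only the last stage owns the final norm and output projection. As a minimal sketch (not part of this commit) of that mapping, assuming n_layers divides evenly by n_stages as the integer division implies:

def stage_layer_ids(n_layers: int, n_stages: int, stage_idx: int) -> list[int]:
    # Same contiguous split as TransformerStage.__init__ above.
    layers_per_stage = n_layers // n_stages
    start = layers_per_stage * stage_idx
    return list(range(start, start + layers_per_stage))

# Hypothetical example: a 32-layer model split over 2 stages.
print(stage_layer_ids(32, 2, 0))  # layers 0..15; stage 0 also holds tok_embeddings
print(stage_layer_ids(32, 2, 1))  # layers 16..31; the last stage also holds norm and output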

dist_run.py

Lines changed: 40 additions & 7 deletions
@@ -4,34 +4,67 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+# Run command:
+# torchrun --nproc-per-node 4 dist_run.py
+
 import torch
 import torch.distributed as dist
+from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

 from build.model import TransformerArgs
-from build.model_dist import Transformer
+from build.model_dist import TransformerStage

 # Model config
 def main():
     config = TransformerArgs.from_name("Transformer-2-7b-chat-hf")
     print(config)

     # Construct a device mesh with available devices (multi-host or single host)
-    device_mesh = dist.init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
+    device_mesh = dist.init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "tp"))
+    tp_mesh = device_mesh["tp"]
+    pp_mesh = device_mesh["pp"]
+    pp_rank = pp_mesh.get_local_rank()
+    nstages = pp_mesh.size()
+
     rank = dist.get_rank()
     device = torch.device(f"cuda:{rank}")

     # Create parallel model with device_mesh context
     with device:
-        with device_mesh:
-            model = Transformer(config)
+        with tp_mesh:
+            model = TransformerStage(config, pp_rank, nstages)
     model.setup_caches(1, 4096)

     print(model)

     # Distributed run
-    input_ids = torch.randint(0, config.vocab_size, (1, 4096), device=device)
-    input_pos = torch.arange(0, 4096, device=device)
-    output = model(input_ids, input_pos)
+    mbs = 2  # number of micro-batches
+    mb_size = 1  # micro-batch size
+    batch_size = mbs * mb_size  # total batch size
+    seqlen = 4096  # sequence length
+    dim = 4096  # embedding dimension
+
+    # Example input for pipeline stages
+    mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
+    activation = torch.rand(mb_size, seqlen, dim, device=device)
+    example_args = mb_ids if pp_rank == 0 else activation
+
+    # Create pipeline stages
+    stage = PipelineStage(
+        model, pp_rank, nstages, device,
+        input_args=(example_args,),
+        group=pp_mesh.get_group(),
+    )
+
+    # Run pipeline
+    schedule = ScheduleGPipe(stage, mbs)
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
+    if pp_rank == 0:
+        schedule.step(input_ids)
+    else:
+        output = schedule.step()
+        print(f"{output=}")
+
     dist.destroy_process_group()
     print(f"Rank {rank} completes.")
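
For reference, here is a small sketch (not part of this commit) of how the (2, 2) mesh above maps the four torchrun ranks to pipeline and tensor-parallel coordinates, assuming the default row-major layout of DeviceMesh for a world size of 4; ScheduleGPipe then feeds the mbs micro-batches of size mb_size through the two stages in order.

import torch

# Assumed default layout for init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "tp")):
# global ranks arranged row-major over the (pp, tp) grid,
# i.e. torch.arange(world_size).reshape(mesh_shape).
mesh = torch.arange(4).reshape(2, 2)
for pp in range(2):
    for tp in range(2):
        print(f"global rank {mesh[pp, tp].item()}: pp_rank={pp}, tp_rank={tp}")
# -> ranks 0-1 form pipeline stage 0, ranks 2-3 form stage 1,
#    and the two ranks within each stage shard that stage's weights via tensor parallelism.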
