
Commit 19a47e7

Initial add of distributed model (#1063)
Use `parallelize_module` in model.

Added files:
- `model_dist.py`: a mirror of model.py with Tensor Parallelism baked in.
- `dist_run.py`: a toy example of how to run the model in a distributed way.

Test: `torchrun --nproc-per-node 2 dist_run.py`

[ghstack-poisoned]
1 parent 2f4ba2d commit 19a47e7

File tree

2 files changed, +317 -0 lines changed


build/model_dist.py

Lines changed: 278 additions & 0 deletions
@@ -0,0 +1,278 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
from typing import Dict, Optional

import torch
import torch.nn as nn

from torch import Tensor
from torch.nn import functional as F
from torch.distributed._tensor import DTensor, Replicate
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel

from build.utils import find_multiple

from build.model import TransformerArgs, KVCache, apply_rotary_emb, precompute_freqs_cis

config_path = Path(f"{str(Path(__file__).parent)}/known_model_params")


# Use DTensor as output, by default
Colwise = ColwiseParallel(use_local_output=False)
Rowwise = RowwiseParallel(use_local_output=False)

# Device mesh context
device_mesh = None


class Transformer(nn.Module):
    def __init__(self, config: TransformerArgs) -> None:
        super().__init__()
        self.config = config

        # Get device mesh
        global device_mesh
        if device_mesh is None:
            device_mesh = _mesh_resources.get_current_mesh()

        tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.tok_embeddings = parallelize_module(
            tok_embeddings,
            device_mesh,
            RowwiseParallel(input_layouts=Replicate()),
        )
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layers)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # self.freqs_cis: Optional[Tensor] = None
        # self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_caches(self, max_batch_size, max_seq_length):
        if (
            self.max_seq_length >= max_seq_length
            and self.max_batch_size >= max_batch_size
        ):
            return
        head_dim = self.config.dim // self.config.n_heads
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
            )

        freqs_cis = precompute_freqs_cis(
            self.config.dim // self.config.n_heads,
            self.config.block_size * 2,
            self.config.rope_base,
            use_scaled=self.config.use_scaled_rope,
        )
        self.register_buffer("freqs_cis", freqs_cis, persistent=True)
        causal_mask = torch.tril(
            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
        )
        self.register_buffer("causal_mask", causal_mask, persistent=True)

    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        assert self.freqs_cis is not None, "Caches must be initialized first"
        mask = self.causal_mask[None, None, input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        x: DTensor = self.tok_embeddings(idx)
        # TODO: sequence parallelize this

        for _, layer in enumerate(self.layers):
            x = layer(x, input_pos, freqs_cis, mask)
        x = self.norm(x)
        logits = self.output(x)
        # print(f"logits shape: {logits.shape}")
        return logits

    @classmethod
    def from_name(cls, name: str):
        return cls(TransformerArgs.from_name(name))

    @classmethod
    def from_table(cls, name: str):
        return cls(TransformerArgs.from_table(name))

    @classmethod
    def from_params(cls, params_path: str):
        return cls(TransformerArgs.from_params(params_path))

    @classmethod
    def from_gguf(cls, gguf_path: str, **kwargs):
        from build.gguf_loader import load_model_and_state_dict

        model, state_dict = load_model_and_state_dict(gguf_path, **kwargs)
        if state_dict != {}:
            model.load_state_dict(state_dict, assign=True)
        return model


class TransformerBlock(nn.Module):
    def __init__(self, config: TransformerArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
    ) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Attention(nn.Module):
    def __init__(self, config: TransformerArgs):
        super().__init__()
        assert config.dim % config.n_heads == 0

        # key, query, value projections for all heads, but in a batch
        # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim
        # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
        wk = nn.Linear(
            config.dim, config.n_local_heads * config.head_dim, bias=False
        )
        wv = nn.Linear(
            config.dim, config.n_local_heads * config.head_dim, bias=False
        )
        wo = nn.Linear(config.dim, config.dim, bias=False)

        self.wq = parallelize_module(wq, device_mesh, Colwise)
        self.wk = parallelize_module(wk, device_mesh, Colwise)
        self.wv = parallelize_module(wv, device_mesh, Colwise)
        self.wo = parallelize_module(wo, device_mesh, Rowwise)

        self.kv_cache = None

        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        # if prefix + "wq.weight" in state_dict:
        #     wq = state_dict.pop(prefix + "wq.weight")
        #     wk = state_dict.pop(prefix + "wk.weight")
        #     wv = state_dict.pop(prefix + "wv.weight")
        #     state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

        if prefix + "wqkv.weight" in state_dict:
            wqkv = state_dict.pop(prefix + "wqkv.weight")
            q_size = self.n_heads * self.head_dim
            kv_size = self.n_local_heads * self.head_dim
            wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
            state_dict[prefix + "wq.weight"] = wq
            state_dict[prefix + "wk.weight"] = wk
            state_dict[prefix + "wv.weight"] = wv

        return

        def _unfuse_wqkv_state_dict(
            state_dict: Dict[str, torch.Tensor],
            dim: int,
        ):
            for key in list(state_dict):
                if key.endswith("wqkv.weight"):
                    tensor = state_dict[key]
                    wq_key = key.replace("wqkv.weight", "wq.weight")
                    state_dict[wq_key] = tensor[:dim]
                    wk_key = key.replace("wqkv.weight", "wk.weight")
                    wv_key = key.replace("wqkv.weight", "wv.weight")
                    wk, wv = tensor[dim:].chunk(2, 0)
                    state_dict[wk_key] = wk
                    state_dict[wv_key] = wv
                    state_dict.pop(key)
                else:
                    continue

        _unfuse_wqkv_state_dict(state_dict, self.dim)

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        q: DTensor = self.wq(x)
        k: DTensor = self.wk(x)
        v: DTensor = self.wv(x)
        # We use `to_local()` to convert DTensor back to regular Tensor
        q, k, v = q.to_local(), k.to_local(), v.to_local()
        # kv_size = self.n_local_heads * self.head_dim
        # q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, -1, self.head_dim)
        k = k.view(bsz, seqlen, -1, self.head_dim)
        v = v.view(bsz, seqlen, -1, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = (x.transpose(1, 2) for x in (q, k, v))

        # TODO: enable kv cache
        # if self.kv_cache is not None:
        #     k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        y: DTensor = self.wo(y)
        # TODO: sequence parallelize this
        return y.full_tensor()


class FeedForward(nn.Module):
    def __init__(self, config: TransformerArgs) -> None:
        super().__init__()
        w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w1 = parallelize_module(w1, device_mesh, Colwise)
        self.w2 = parallelize_module(w2, device_mesh, Rowwise)
        self.w3 = parallelize_module(w3, device_mesh, Colwise)

    def forward(self, x: Tensor) -> Tensor:
        y: DTensor = self.w2(F.silu(self.w1(x)) * self.w3(x))
        # y is a DTensor with Partial placement;
        # we convert its placement to Replicate and convert it back to a regular
        # Tensor. `full_tensor` is the API that does both.
        # TODO: sequence parallelize this
        return y.full_tensor()


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

dist_run.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.distributed as dist

from build.model import TransformerArgs
from build.model_dist import Transformer

# Model config
def main():
    config = TransformerArgs.from_name("Transformer-2-7b-chat-hf")
    print(config)

    # Construct a device mesh with available devices (multi-host or single host)
    device_mesh = dist.init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")

    # Create parallel model with device_mesh context
    with device:
        with device_mesh:
            model = Transformer(config)
            model.setup_caches(1, 4096)

    print(model)

    # Distributed run
    input_ids = torch.randint(0, config.vocab_size, (1, 4096), device=device)
    input_pos = torch.arange(0, 4096, device=device)
    output = model(input_ids, input_pos)
    dist.destroy_process_group()
    print(f"Rank {rank} completes.")

if __name__ == "__main__":
    main()
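
Both new files lean on the same tensor-parallel pattern from torch.distributed: create a device mesh, wrap plain nn.Linear modules with parallelize_module using the ColwiseParallel / RowwiseParallel styles (keeping DTensor outputs via use_local_output=False), and convert back to regular tensors with to_local() or full_tensor() where the math needs them. The snippet below is a minimal, standalone sketch of that pattern, not part of this commit; the toy dimensions, the tp_sketch.py file name, and the 2-GPU assumption are illustrative, and it is launched the same way as the test command above, e.g. torchrun --nproc-per-node 2 tp_sketch.py.

# tp_sketch.py -- minimal sketch (not part of this commit) of the pattern
# model_dist.py uses: column-/row-parallel Linear layers over a size-2 "tp"
# mesh, DTensor outputs, and to_local()/full_tensor() conversions.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import DTensor
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


def main():
    # torchrun starts one process per GPU; together they form the "tp" mesh.
    mesh = dist.init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(0)  # same input on every rank; TP expects replicated inputs

    w1 = nn.Linear(16, 64, bias=False).cuda()
    w2 = nn.Linear(64, 16, bias=False).cuda()
    # Same styles as model_dist.py: keep DTensor outputs instead of local tensors.
    w1 = parallelize_module(w1, mesh, ColwiseParallel(use_local_output=False))
    w2 = parallelize_module(w2, mesh, RowwiseParallel(use_local_output=False))

    x = torch.randn(4, 16, device="cuda")
    h: DTensor = w1(x)        # column-parallel output, sharded on the last dim
    h_local = h.to_local()    # this rank's shard as a plain Tensor, (4, 32) on 2 GPUs
    y: DTensor = w2(h)        # row-parallel output DTensor
    y_full = y.full_tensor()  # materialize the full, replicated (4, 16) Tensor

    print(f"rank {rank}: shard {tuple(h_local.shape)}, output {tuple(y_full.shape)}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Keeping use_local_output=False is what lets model_dist.py choose per call site whether it wants the local shard (to_local() before the attention math) or the fully materialized tensor (full_tensor() after the row-parallel projections).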
