
Commit 976f939

[Distributed] Use Tensor Parallel instead of Sequence Parallel
1 parent 0bc64c2 commit 976f939

2 files changed: 9 additions & 40 deletions


dist_run.py

Lines changed: 4 additions & 6 deletions
@@ -255,12 +255,11 @@ def main(args):
     assert world_size % pp_degree == 0
     assert config.n_layers % pp_degree == 0
 
-    # Sequence parallel is enabled in this program
-    # Sequence parallel = Tensor parallel + dividing sequence by tp_degree at layer boundary
-    sp_degree = world_size // pp_degree
+    # Tensor parallel is enabled in this program
+    tp_degree = world_size // pp_degree
 
     # Create device mesh
-    mesh_dimensions = (pp_degree, sp_degree)
+    mesh_dimensions = (pp_degree, tp_degree)
     device_mesh = _create_device_mesh(mesh_dimensions)
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
@@ -295,7 +294,6 @@ def main(args):
 
     seqlen = 4096  # sequence length
     dim = 4096  # embedding dimension
-    assert seqlen % sp_degree == 0
 
     # Setup KV caches (after model distribution)
     # TODO: the setting below only works for 1 micro-batch case. To support
@@ -305,7 +303,7 @@ def main(args):
 
     mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
     activation = torch.rand(
-        mb_size, seqlen // sp_degree, dim, device=device, dtype=model_dtype
+        mb_size, seqlen, dim, device=device, dtype=model_dtype
    )
     example_args = mb_ids if pp_rank == 0 else activation
 
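For context, here is a minimal sketch of how a 2-D (pp, tp) device mesh like the one sliced above can be built with torch.distributed.device_mesh.init_device_mesh. This is an illustration only, not the repository's _create_device_mesh; the helper name create_pp_tp_mesh is hypothetical, and it assumes torchrun has already initialized the default process group.

# Sketch only: builds a (pp, tp) mesh matching the dimensions used above.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def create_pp_tp_mesh(pp_degree: int, tp_degree: int):
    # Mirrors the asserts at the top of main(): world_size must factor
    # exactly into pp_degree * tp_degree.
    assert dist.get_world_size() == pp_degree * tp_degree
    return init_device_mesh(
        "cuda",
        (pp_degree, tp_degree),
        mesh_dim_names=("pp", "tp"),
    )


# Example: 8 GPUs as 2 pipeline stages x 4-way tensor parallel.
# device_mesh = create_pp_tp_mesh(pp_degree=2, tp_degree=4)
# tp_mesh, pp_mesh = device_mesh["tp"], device_mesh["pp"]

Because plain tensor parallelism keeps activations replicated along the sequence dimension, the pipeline activation buffer above is allocated at the full (mb_size, seqlen, dim) shape rather than seqlen // sp_degree.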

torchchat/model.py

Lines changed: 5 additions & 34 deletions
@@ -18,7 +18,7 @@
 import torch.nn as nn
 
 from torch import Tensor
-from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed._tensor import DTensor, Replicate
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -428,8 +428,6 @@ def __init__(self, config: TransformerArgs) -> None:
 
         self.max_batch_size = -1
         self.max_seq_length = -1
-        # For supporting sequence parallel (default is off, thus value of 1)
-        self.seq_parallel_degree = 1
 
     def setup_caches(self, max_batch_size, max_seq_length):
         if (
@@ -465,30 +463,19 @@ def distribute(self, device_mesh: DeviceMesh):
         parallelize_module(
             self.tok_embeddings,
             device_mesh,
-            RowwiseParallel(
-                input_layouts=Replicate(),
-                output_layouts=Shard(1),
-            ),
+            RowwiseParallel(input_layouts=Replicate()),
         )
 
         for layer in self.layers.values():
             layer.distribute(device_mesh)
 
-        if self.norm:
-            parallelize_module(self.norm, device_mesh, SequenceParallel())
-
         if self.output:
             parallelize_module(
                 self.output,
                 device_mesh,
-                ColwiseParallel(
-                    input_layouts=Shard(1),
-                    output_layouts=Replicate(),
-                ),
+                ColwiseParallel(output_layouts=Replicate()),
             )
 
-        self.seq_parallel_degree = device_mesh.size()
-
     # This is a temporary solution to pass input_pos to non-0 pipeline stages
     # TODO: make `step()` function of dist.pipelining accept args for non-0 stages
     def setup_input_pos(self, input_pos: Tensor) -> None:
@@ -525,8 +512,6 @@ def __init__(self, config: TransformerArgs) -> None:
     def distribute(self, device_mesh: DeviceMesh):
         self.attention.distribute(device_mesh)
         self.feed_forward.distribute(device_mesh)
-        parallelize_module(self.ffn_norm, device_mesh, SequenceParallel())
-        parallelize_module(self.attention_norm, device_mesh, SequenceParallel())
 
     def forward(
         self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
@@ -610,14 +595,11 @@ def _unfuse_wqkv_state_dict(
         _unfuse_wqkv_state_dict(state_dict, self.dim)
 
     def distribute(self, device_mesh: DeviceMesh):
-        self.device_mesh = device_mesh
         self.tp_degree = device_mesh.size()
         parallelize_module(self.wq, device_mesh, ColwiseParallel())
         parallelize_module(self.wk, device_mesh, ColwiseParallel())
         parallelize_module(self.wv, device_mesh, ColwiseParallel())
-        parallelize_module(
-            self.wo, device_mesh, RowwiseParallel(output_layouts=Shard(1))
-        )
+        parallelize_module(self.wo, device_mesh, RowwiseParallel())
 
     def forward(
         self,
@@ -626,10 +608,6 @@ def forward(
         mask: Tensor,
         input_pos: Optional[Tensor] = None,
     ) -> Tensor:
-        # Gather sequence back in case of sequence parallelism before attention
-        if isinstance(x, DTensor):
-            x = x.redistribute(self.device_mesh, [Replicate()])
-
         bsz, seqlen, _ = x.shape
 
         q = self.wq(x)
@@ -675,18 +653,11 @@ def __init__(self, config: TransformerArgs) -> None:
         self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
 
     def distribute(self, device_mesh: DeviceMesh):
-        self.device_mesh = device_mesh
         parallelize_module(self.w1, device_mesh, ColwiseParallel())
-        parallelize_module(
-            self.w2, device_mesh, RowwiseParallel(output_layouts=Shard(1))
-        )
+        parallelize_module(self.w2, device_mesh, RowwiseParallel())
         parallelize_module(self.w3, device_mesh, ColwiseParallel())
 
     def forward(self, x: Tensor) -> Tensor:
-        # Gather sequence back in case of sequence parallelism
-        if isinstance(x, DTensor):
-            x = x.redistribute(self.device_mesh, [Replicate()])
-
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
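Below is a minimal, self-contained sketch of the tensor-parallel plan the diff converges on: column-parallel in-projections paired with a row-parallel out-projection, with no Shard(1) sequence layouts, so module outputs stay replicated and forward() needs no explicit redistribute. ToyFeedForward and distribute_ffn are illustrative stand-ins, not the torchchat classes.

# Sketch only: illustrates the Colwise/Rowwise pairing used in the diff.
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyFeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


def distribute_ffn(ffn: ToyFeedForward, tp_mesh: DeviceMesh) -> None:
    # Colwise shards the output features of w1/w3 across the tp mesh;
    # Rowwise shards the input features of w2 and all-reduces its output
    # back to a replicated tensor, so forward() needs no redistribute().
    parallelize_module(ffn.w1, tp_mesh, ColwiseParallel())
    parallelize_module(ffn.w3, tp_mesh, ColwiseParallel())
    parallelize_module(ffn.w2, tp_mesh, RowwiseParallel())

The diff applies the same pairing to the attention projections: wq/wk/wv are column-parallel and wo is row-parallel, which is why the DTensor redistribute calls and the SequenceParallel norm wrappers can be dropped.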
