This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit e27e162

[Distributed] Use Tensor Parallel instead of Sequence Parallel (#1160)

1 parent 4774eaf, commit e27e162

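For context, the change amounts to dropping the sequence-dimension sharding between layers: under sequence parallel, activations crossed layer boundaries as Shard(1) DTensors and the norms ran under SequenceParallel; under plain tensor parallel, activations stay replicated and only the per-layer weights are sharded. A minimal side-by-side sketch of the two plans (illustrative only, with hypothetical module names, not code from this commit):

```python
# Illustrative contrast of the two parallelization plans (hypothetical module
# names; not code from this commit).
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
)

# Before: sequence parallel -- activations sharded on the sequence dim (Shard(1))
# between layers, norms wrapped in SequenceParallel.
sequence_parallel_plan = {
    "tok_embeddings": RowwiseParallel(
        input_layouts=Replicate(), output_layouts=Shard(1)
    ),
    "norm": SequenceParallel(),
    "output": ColwiseParallel(
        input_layouts=Shard(1), output_layouts=Replicate()
    ),
}

# After: tensor parallel -- activations replicated between layers; only the
# module weights are sharded.
tensor_parallel_plan = {
    "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
    "output": ColwiseParallel(output_layouts=Replicate()),
}
```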
File tree

2 files changed: 9 additions & 40 deletions
    dist_run.py
    torchchat/model.py

dist_run.py

Lines changed: 4 additions & 6 deletions
@@ -258,12 +258,11 @@ def main(args):
     assert world_size % pp_degree == 0
     assert config.n_layers % pp_degree == 0
 
-    # Sequence parallel is enabled in this program
-    # Sequence parallel = Tensor parallel + dividing sequence by tp_degree at layer boundary
-    sp_degree = world_size // pp_degree
+    # Tensor parallel is enabled in this program
+    tp_degree = world_size // pp_degree
 
     # Create device mesh
-    mesh_dimensions = (pp_degree, sp_degree)
+    mesh_dimensions = (pp_degree, tp_degree)
     device_mesh = _create_device_mesh(mesh_dimensions)
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
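The (pp_degree, tp_degree) mesh is what makes the device_mesh["tp"] and device_mesh["pp"] slices above work. A sketch of what a helper like _create_device_mesh might do with init_device_mesh (an assumption; the actual helper in dist_run.py may differ):

```python
# Hypothetical sketch of a mesh helper; the real _create_device_mesh in
# dist_run.py may be implemented differently.
from torch.distributed.device_mesh import init_device_mesh

def _create_device_mesh_sketch(mesh_dimensions):
    pp_degree, tp_degree = mesh_dimensions
    # A 2-D mesh named ("pp", "tp") so that mesh["pp"] and mesh["tp"] return
    # the corresponding 1-D sub-meshes.
    return init_device_mesh(
        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
    )
```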
@@ -299,7 +298,6 @@ def main(args):
 
     seqlen = 4096  # sequence length
     dim = 4096  # embedding dimension
-    assert seqlen % sp_degree == 0
 
     # Setup KV caches (after model distribution)
     # TODO: the setting below only works for 1 micro-batch case. To support
@@ -309,7 +307,7 @@ def main(args):
 
     mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
     activation = torch.rand(
-        mb_size, seqlen // sp_degree, dim, device=device, dtype=model_dtype
+        mb_size, seqlen, dim, device=device, dtype=model_dtype
    )
     example_args = mb_ids if pp_rank == 0 else activation

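The two hunks above follow from the same change: with tensor parallel only, every rank holds the full sequence, so the seqlen % sp_degree assert is moot and the example activation for non-first pipeline stages keeps its full seqlen dimension. A tiny illustration of the shape difference (mb_size = 1 and tp_degree = 4 are assumptions made for the example):

```python
# Shapes only; mb_size = 1 and tp_degree = 4 are assumed values for illustration.
mb_size, seqlen, dim, tp_degree = 1, 4096, 4096, 4
old_sp_shape = (mb_size, seqlen // tp_degree, dim)  # (1, 1024, 4096): sequence sharded
new_tp_shape = (mb_size, seqlen, dim)               # (1, 4096, 4096): full, replicated sequence
```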
torchchat/model.py

Lines changed: 5 additions & 34 deletions
@@ -21,7 +21,7 @@
 import torch.nn as nn
 
 from torch import Tensor
-from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed._tensor import DTensor, Replicate
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -605,8 +605,6 @@ def __init__(self, config: TransformerArgs) -> None:
 
         self.max_batch_size = -1
         self.max_seq_length = -1
-        # For supporting sequence parallel (default is off, thus value of 1)
-        self.seq_parallel_degree = 1
 
     def setup_caches(self, max_batch_size, max_seq_length):
         if (
@@ -642,30 +640,19 @@ def distribute(self, device_mesh: DeviceMesh):
         parallelize_module(
             self.tok_embeddings,
             device_mesh,
-            RowwiseParallel(
-                input_layouts=Replicate(),
-                output_layouts=Shard(1),
-            ),
+            RowwiseParallel(input_layouts=Replicate()),
         )
 
         for layer in self.layers.values():
             layer.distribute(device_mesh)
 
-        if self.norm:
-            parallelize_module(self.norm, device_mesh, SequenceParallel())
-
         if self.output:
             parallelize_module(
                 self.output,
                 device_mesh,
-                ColwiseParallel(
-                    input_layouts=Shard(1),
-                    output_layouts=Replicate(),
-                ),
+                ColwiseParallel(output_layouts=Replicate()),
             )
 
-        self.seq_parallel_degree = device_mesh.size()
-
     # This is a temporary solution to pass input_pos to non-0 pipeline stages
     # TODO: make `step()` function of dist.pipelining accept args for non-0 stages
     def setup_input_pos(self, input_pos: Tensor) -> None:
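The resulting plan keeps activations replicated at module boundaries: the token embedding is row-wise sharded but takes replicated token ids, the norms are left alone, and the output projection is column-wise sharded with its result gathered back to a replicated tensor. A small sketch applying the same two styles to a toy module, runnable under torchrun on a recent PyTorch with the DTensor tensor-parallel APIs (the toy names and sizes are made up; e.g. `torchrun --nproc-per-node=2 toy_tp.py`):

```python
# Toy sketch of the embedding/output styles used above; launch with torchrun.
# ToyModel and its sizes are hypothetical, not torchchat code.
import torch
import torch.nn as nn
from torch.distributed._tensor import Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class ToyModel(nn.Module):
    def __init__(self, vocab=32, dim=16):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.output = nn.Linear(dim, vocab, bias=False)

    def forward(self, ids):
        return self.output(self.tok_embeddings(ids))

if __name__ == "__main__":
    # 1-D tensor-parallel mesh over 2 ranks (CPU/gloo just for the sketch).
    mesh = init_device_mesh("cpu", (2,), mesh_dim_names=("tp",))
    model = ToyModel()
    # Embedding: weight sharded row-wise, input ids replicated, output replicated.
    parallelize_module(
        model.tok_embeddings, mesh, RowwiseParallel(input_layouts=Replicate())
    )
    # Output projection: weight sharded column-wise, result gathered to Replicate.
    parallelize_module(
        model.output, mesh, ColwiseParallel(output_layouts=Replicate())
    )
    logits = model(torch.randint(0, 32, (1, 8)))
    print(logits.shape)  # torch.Size([1, 8, 32]) on every rank
```

Because every parallelized module hands back a replicated (or gathered) result, the Shard(1) layouts, the seq_parallel_degree bookkeeping, and the explicit redistribute calls removed elsewhere in this commit are no longer needed.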
@@ -702,8 +689,6 @@ def __init__(self, config: TransformerArgs) -> None:
     def distribute(self, device_mesh: DeviceMesh):
         self.attention.distribute(device_mesh)
         self.feed_forward.distribute(device_mesh)
-        parallelize_module(self.ffn_norm, device_mesh, SequenceParallel())
-        parallelize_module(self.attention_norm, device_mesh, SequenceParallel())
 
     def forward(
         self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
@@ -787,14 +772,11 @@ def _unfuse_wqkv_state_dict(
         _unfuse_wqkv_state_dict(state_dict, self.dim)
 
     def distribute(self, device_mesh: DeviceMesh):
-        self.device_mesh = device_mesh
         self.tp_degree = device_mesh.size()
         parallelize_module(self.wq, device_mesh, ColwiseParallel())
         parallelize_module(self.wk, device_mesh, ColwiseParallel())
         parallelize_module(self.wv, device_mesh, ColwiseParallel())
-        parallelize_module(
-            self.wo, device_mesh, RowwiseParallel(output_layouts=Shard(1))
-        )
+        parallelize_module(self.wo, device_mesh, RowwiseParallel())
 
     def forward(
         self,
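The column-wise/row-wise pairing is why wo no longer needs a Shard(1) output layout: each rank computes a partial output whose sum (the all-reduce RowwiseParallel performs) equals the unsharded result, so the default replicated output is already the full tensor. A single-process numeric sketch of that identity, with plain tensors standing in for the sharded weights (sizes are arbitrary):

```python
# Single-process check of the colwise -> rowwise pattern used for the attention
# projections (wq/wk/wv column-wise, wo row-wise). Each "rank" holds a column
# slice of the first weight and the matching row slice of the second; the sum
# of the per-rank partial outputs matches the unsharded result. Illustrative only.
import torch

torch.manual_seed(0)
tp = 2                      # tensor-parallel degree
x = torch.randn(3, 8)       # (tokens, dim)
w_in = torch.randn(16, 8)   # stand-in for a colwise-sharded projection
w_out = torch.randn(8, 16)  # stand-in for the rowwise-sharded wo

reference = (x @ w_in.t()) @ w_out.t()

partials = []
for rank in range(tp):
    w_in_shard = w_in.chunk(tp, dim=0)[rank]    # shard over output features
    w_out_shard = w_out.chunk(tp, dim=1)[rank]  # shard over input features
    partials.append((x @ w_in_shard.t()) @ w_out_shard.t())

assert torch.allclose(sum(partials), reference, atol=1e-5)
```

The attention math between the projections is computed per head, so it stays rank-local as long as the column shards line up with head boundaries.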
@@ -803,10 +785,6 @@ def forward(
         mask: Tensor,
         input_pos: Optional[Tensor] = None,
     ) -> Tensor:
-        # Gather sequence back in case of sequence parallelism before attention
-        if isinstance(x, DTensor):
-            x = x.redistribute(self.device_mesh, [Replicate()])
-
         bsz, seqlen, _ = x.shape
 
         q = self.wq(x)
@@ -852,18 +830,11 @@ def __init__(self, config: TransformerArgs) -> None:
         self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
 
     def distribute(self, device_mesh: DeviceMesh):
-        self.device_mesh = device_mesh
         parallelize_module(self.w1, device_mesh, ColwiseParallel())
-        parallelize_module(
-            self.w2, device_mesh, RowwiseParallel(output_layouts=Shard(1))
-        )
+        parallelize_module(self.w2, device_mesh, RowwiseParallel())
         parallelize_module(self.w3, device_mesh, ColwiseParallel())
 
     def forward(self, x: Tensor) -> Tensor:
-        # Gather sequence back in case of sequence parallelism
-        if isinstance(x, DTensor):
-            x = x.redistribute(self.device_mesh, [Replicate()])
-
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
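The same identity covers the feed-forward plan above: because w1 and w3 are column-sharded over identical hidden slices, the element-wise gate in `F.silu(self.w1(x)) * self.w3(x)` is computed shard-locally, and the row-wise w2 partials all-reduce to the full output, so no Shard(1) layout or redistribute is needed in `forward`. A single-process numeric sketch (arbitrary sizes):

```python
# Single-process check of the SwiGLU feed-forward plan (w1/w3 column-wise,
# w2 row-wise): the gate works shard-locally because w1 and w3 share the same
# hidden slices, and the w2 partials sum to the unsharded result. Illustrative only.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
tp = 2
dim, hidden = 8, 16
x = torch.randn(3, dim)
w1 = torch.randn(hidden, dim)
w2 = torch.randn(dim, hidden)
w3 = torch.randn(hidden, dim)

reference = (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()

partials = []
for rank in range(tp):
    w1_s = w1.chunk(tp, dim=0)[rank]
    w3_s = w3.chunk(tp, dim=0)[rank]
    w2_s = w2.chunk(tp, dim=1)[rank]
    gated = F.silu(x @ w1_s.t()) * (x @ w3_s.t())  # local SwiGLU on this shard
    partials.append(gated @ w2_s.t())

assert torch.allclose(sum(partials), reference, atol=1e-5)
```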
