diff --git a/dist_run.py b/dist_run.py
index 6c87c036f..8bcbd5e4d 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -258,12 +258,11 @@ def main(args):
     assert world_size % pp_degree == 0
     assert config.n_layers % pp_degree == 0
 
-    # Sequence parallel is enabled in this program
-    # Sequence parallel = Tensor parallel + dividing sequence by tp_degree at layer boundary
-    sp_degree = world_size // pp_degree
+    # Tensor parallel is enabled in this program
+    tp_degree = world_size // pp_degree
 
     # Create device mesh
-    mesh_dimensions = (pp_degree, sp_degree)
+    mesh_dimensions = (pp_degree, tp_degree)
     device_mesh = _create_device_mesh(mesh_dimensions)
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
@@ -299,7 +298,6 @@ def main(args):
 
     seqlen = 4096  # sequence length
     dim = 4096  # embedding dimension
-    assert seqlen % sp_degree == 0
 
     # Setup KV caches (after model distribution)
     # TODO: the setting below only works for 1 micro-batch case. To support
@@ -309,7 +307,7 @@ def main(args):
 
     mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
     activation = torch.rand(
-        mb_size, seqlen // sp_degree, dim, device=device, dtype=model_dtype
+        mb_size, seqlen, dim, device=device, dtype=model_dtype
     )
     example_args = mb_ids if pp_rank == 0 else activation
 
diff --git a/torchchat/model.py b/torchchat/model.py
index a576d5036..aaa72cb2a 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -21,7 +21,7 @@
 import torch.nn as nn
 from torch import Tensor
-from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed._tensor import DTensor, Replicate
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -605,8 +605,6 @@ def __init__(self, config: TransformerArgs) -> None:
 
         self.max_batch_size = -1
         self.max_seq_length = -1
-        # For supporting sequence parallel (default is off, thus value of 1)
-        self.seq_parallel_degree = 1
 
     def setup_caches(self, max_batch_size, max_seq_length):
         if (
@@ -642,30 +640,19 @@ def distribute(self, device_mesh: DeviceMesh):
             parallelize_module(
                 self.tok_embeddings,
                 device_mesh,
-                RowwiseParallel(
-                    input_layouts=Replicate(),
-                    output_layouts=Shard(1),
-                ),
+                RowwiseParallel(input_layouts=Replicate()),
             )
 
         for layer in self.layers.values():
             layer.distribute(device_mesh)
 
-        if self.norm:
-            parallelize_module(self.norm, device_mesh, SequenceParallel())
-
         if self.output:
             parallelize_module(
                 self.output,
                 device_mesh,
-                ColwiseParallel(
-                    input_layouts=Shard(1),
-                    output_layouts=Replicate(),
-                ),
+                ColwiseParallel(output_layouts=Replicate()),
             )
 
-        self.seq_parallel_degree = device_mesh.size()
-
     # This is a temporary solution to pass input_pos to non-0 pipeline stages
     # TODO: make `step()` function of dist.pipelining accept args for non-0 stages
     def setup_input_pos(self, input_pos: Tensor) -> None:
@@ -702,8 +689,6 @@ def __init__(self, config: TransformerArgs) -> None:
     def distribute(self, device_mesh: DeviceMesh):
         self.attention.distribute(device_mesh)
         self.feed_forward.distribute(device_mesh)
-        parallelize_module(self.ffn_norm, device_mesh, SequenceParallel())
-        parallelize_module(self.attention_norm, device_mesh, SequenceParallel())
 
     def forward(
         self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
@@ -787,14 +772,11 @@ def _unfuse_wqkv_state_dict(
         _unfuse_wqkv_state_dict(state_dict, self.dim)
 
     def distribute(self, device_mesh: DeviceMesh):
-        self.device_mesh = device_mesh
         self.tp_degree = device_mesh.size()
         parallelize_module(self.wq, device_mesh, ColwiseParallel())
         parallelize_module(self.wk, device_mesh, ColwiseParallel())
         parallelize_module(self.wv, device_mesh, ColwiseParallel())
-        parallelize_module(
-            self.wo, device_mesh, RowwiseParallel(output_layouts=Shard(1))
-        )
+        parallelize_module(self.wo, device_mesh, RowwiseParallel())
 
     def forward(
         self,
@@ -803,10 +785,6 @@ def forward(
         mask: Tensor,
         input_pos: Optional[Tensor] = None,
     ) -> Tensor:
-        # Gather sequence back in case of sequence parallelism before attention
-        if isinstance(x, DTensor):
-            x = x.redistribute(self.device_mesh, [Replicate()])
-
         bsz, seqlen, _ = x.shape
 
         q = self.wq(x)
@@ -852,18 +830,11 @@ def __init__(self, config: TransformerArgs) -> None:
         self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
 
     def distribute(self, device_mesh: DeviceMesh):
-        self.device_mesh = device_mesh
         parallelize_module(self.w1, device_mesh, ColwiseParallel())
-        parallelize_module(
-            self.w2, device_mesh, RowwiseParallel(output_layouts=Shard(1))
-        )
+        parallelize_module(self.w2, device_mesh, RowwiseParallel())
         parallelize_module(self.w3, device_mesh, ColwiseParallel())
 
     def forward(self, x: Tensor) -> Tensor:
-        # Gather sequence back in case of sequence parallelism
-        if isinstance(x, DTensor):
-            x = x.redistribute(self.device_mesh, [Replicate()])
-
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
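
Note: after this change every parallelize_module call uses the plain tensor-parallel layouts (ColwiseParallel for the QKV/up projections, RowwiseParallel with a replicated output for the output/down projections), so activations stay replicated across TP ranks and the DTensor redistribute in forward() is no longer needed. Below is a minimal, self-contained sketch of that plan on a toy feed-forward block; ToyFeedForward, the file name tp_sketch.py, the dimensions, and the torchrun launch are illustrative assumptions, not code from this repository.

# tp_sketch.py -- illustrative sketch only; run with: torchrun --nproc-per-node=2 tp_sketch.py
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyFeedForward(nn.Module):
    """Stand-in for a SwiGLU feed-forward block; not torchchat's class."""

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No DTensor redistribute here: with plain TP the input is already replicated.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


if __name__ == "__main__":
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)

    # 1-D mesh over all ranks; in dist_run.py this role is played by the "tp" sub-mesh.
    tp_mesh = init_device_mesh("cuda", (world_size,))

    ff = ToyFeedForward(dim=256, hidden_dim=1024).cuda()
    # Colwise for the up-projections, Rowwise for the down-projection; the Rowwise
    # output is replicated by default, so no Shard(1) / sequence gather is involved.
    parallelize_module(ff.w1, tp_mesh, ColwiseParallel())
    parallelize_module(ff.w2, tp_mesh, RowwiseParallel())
    parallelize_module(ff.w3, tp_mesh, ColwiseParallel())

    x = torch.rand(2, 16, 256, device="cuda")  # (batch, seqlen, dim), full sequence on every rank
    y = ff(x)
    print(f"rank {dist.get_rank()}: output shape {tuple(y.shape)}")  # (2, 16, 256) everywhere

    dist.destroy_process_group()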