 import torch.nn as nn
 
 from torch import Tensor
-from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed._tensor import DTensor, Replicate
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -428,8 +428,6 @@ def __init__(self, config: TransformerArgs) -> None:
 
         self.max_batch_size = -1
         self.max_seq_length = -1
-        # For supporting sequence parallel (default is off, thus value of 1)
-        self.seq_parallel_degree = 1
 
     def setup_caches(self, max_batch_size, max_seq_length):
         if (
@@ -465,30 +463,19 @@ def distribute(self, device_mesh: DeviceMesh):
         parallelize_module(
             self.tok_embeddings,
             device_mesh,
-            RowwiseParallel(
-                input_layouts=Replicate(),
-                output_layouts=Shard(1),
-            ),
+            RowwiseParallel(input_layouts=Replicate()),
         )
 
         for layer in self.layers.values():
             layer.distribute(device_mesh)
 
-        if self.norm:
-            parallelize_module(self.norm, device_mesh, SequenceParallel())
-
         if self.output:
             parallelize_module(
                 self.output,
                 device_mesh,
-                ColwiseParallel(
-                    input_layouts=Shard(1),
-                    output_layouts=Replicate(),
-                ),
+                ColwiseParallel(output_layouts=Replicate()),
             )
 
-        self.seq_parallel_degree = device_mesh.size()
-
     # This is a temporary solution to pass input_pos to non-0 pipeline stages
     # TODO: make `step()` function of dist.pipelining accept args for non-0 stages
     def setup_input_pos(self, input_pos: Tensor) -> None:
@@ -525,8 +512,6 @@ def __init__(self, config: TransformerArgs) -> None:
     def distribute(self, device_mesh: DeviceMesh):
         self.attention.distribute(device_mesh)
         self.feed_forward.distribute(device_mesh)
-        parallelize_module(self.ffn_norm, device_mesh, SequenceParallel())
-        parallelize_module(self.attention_norm, device_mesh, SequenceParallel())
 
     def forward(
         self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
@@ -610,14 +595,11 @@ def _unfuse_wqkv_state_dict(
         _unfuse_wqkv_state_dict(state_dict, self.dim)
 
     def distribute(self, device_mesh: DeviceMesh):
-        self.device_mesh = device_mesh
         self.tp_degree = device_mesh.size()
         parallelize_module(self.wq, device_mesh, ColwiseParallel())
         parallelize_module(self.wk, device_mesh, ColwiseParallel())
         parallelize_module(self.wv, device_mesh, ColwiseParallel())
-        parallelize_module(
-            self.wo, device_mesh, RowwiseParallel(output_layouts=Shard(1))
-        )
+        parallelize_module(self.wo, device_mesh, RowwiseParallel())
 
     def forward(
         self,
@@ -626,10 +608,6 @@ def forward(
         mask: Tensor,
         input_pos: Optional[Tensor] = None,
     ) -> Tensor:
-        # Gather sequence back in case of sequence parallelism before attention
-        if isinstance(x, DTensor):
-            x = x.redistribute(self.device_mesh, [Replicate()])
-
         bsz, seqlen, _ = x.shape
 
         q = self.wq(x)
@@ -675,18 +653,11 @@ def __init__(self, config: TransformerArgs) -> None:
         self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
 
     def distribute(self, device_mesh: DeviceMesh):
-        self.device_mesh = device_mesh
         parallelize_module(self.w1, device_mesh, ColwiseParallel())
-        parallelize_module(
-            self.w2, device_mesh, RowwiseParallel(output_layouts=Shard(1))
-        )
+        parallelize_module(self.w2, device_mesh, RowwiseParallel())
         parallelize_module(self.w3, device_mesh, ColwiseParallel())
 
     def forward(self, x: Tensor) -> Tensor:
-        # Gather sequence back in case of sequence parallelism
-        if isinstance(x, DTensor):
-            x = x.redistribute(self.device_mesh, [Replicate()])
-
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
 
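For context, the plain tensor-parallel plan that remains after this change can be exercised on its own with `parallelize_module`. The sketch below is not the repository's code: it applies the same `ColwiseParallel`/`RowwiseParallel` layout to a toy two-layer MLP, assuming a single-node launch via `torchrun` with one GPU per rank; the names `ToyMLP`, `w1`, and `w2` are illustrative only.

```python
# Minimal sketch of the pure tensor-parallel layout kept by this diff.
# Assumes: launched with `torchrun --nproc-per-node=N`, one CUDA device per rank.
# ToyMLP / w1 / w2 are hypothetical names, not taken from the repository.
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor import Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    def __init__(self, dim: int = 256, hidden: int = 1024):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)  # to be column-sharded
        self.w2 = nn.Linear(hidden, dim, bias=False)  # to be row-sharded

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)))


def main():
    # Bind this rank to its local GPU and build a 1-D mesh over all ranks.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = int(os.environ["WORLD_SIZE"])
    mesh = init_device_mesh("cuda", (world_size,))

    model = ToyMLP().cuda()
    # Colwise shards w1's output features; Rowwise shards w2's input features
    # and all-reduces its output, so callers see a replicated activation.
    parallelize_module(model.w1, mesh, ColwiseParallel())
    parallelize_module(model.w2, mesh, RowwiseParallel(output_layouts=Replicate()))

    torch.manual_seed(0)  # same (replicated) input on every rank
    x = torch.randn(2, 16, 256, device="cuda")
    y = model(x)
    print(y.shape)  # torch.Size([2, 16, 256]) on every rank


if __name__ == "__main__":
    main()
```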