 import torch.nn as nn
 
 from torch import Tensor
-from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed._tensor import DTensor, Replicate
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -605,8 +605,6 @@ def __init__(self, config: TransformerArgs) -> None:
 
         self.max_batch_size = -1
         self.max_seq_length = -1
-        # For supporting sequence parallel (default is off, thus value of 1)
-        self.seq_parallel_degree = 1
 
     def setup_caches(self, max_batch_size, max_seq_length):
         if (
@@ -642,30 +640,19 @@ def distribute(self, device_mesh: DeviceMesh):
             parallelize_module(
                 self.tok_embeddings,
                 device_mesh,
-                RowwiseParallel(
-                    input_layouts=Replicate(),
-                    output_layouts=Shard(1),
-                ),
+                RowwiseParallel(input_layouts=Replicate()),
             )
 
         for layer in self.layers.values():
             layer.distribute(device_mesh)
 
-        if self.norm:
-            parallelize_module(self.norm, device_mesh, SequenceParallel())
-
         if self.output:
             parallelize_module(
                 self.output,
                 device_mesh,
-                ColwiseParallel(
-                    input_layouts=Shard(1),
-                    output_layouts=Replicate(),
-                ),
+                ColwiseParallel(output_layouts=Replicate()),
             )
 
-        self.seq_parallel_degree = device_mesh.size()
-
     # This is a temporary solution to pass input_pos to non-0 pipeline stages
     # TODO: make `step()` function of dist.pipelining accept args for non-0 stages
     def setup_input_pos(self, input_pos: Tensor) -> None:
@@ -702,8 +689,6 @@ def __init__(self, config: TransformerArgs) -> None:
     def distribute(self, device_mesh: DeviceMesh):
         self.attention.distribute(device_mesh)
         self.feed_forward.distribute(device_mesh)
-        parallelize_module(self.ffn_norm, device_mesh, SequenceParallel())
-        parallelize_module(self.attention_norm, device_mesh, SequenceParallel())
 
     def forward(
         self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
@@ -787,14 +772,11 @@ def _unfuse_wqkv_state_dict(
         _unfuse_wqkv_state_dict(state_dict, self.dim)
 
     def distribute(self, device_mesh: DeviceMesh):
-        self.device_mesh = device_mesh
         self.tp_degree = device_mesh.size()
         parallelize_module(self.wq, device_mesh, ColwiseParallel())
         parallelize_module(self.wk, device_mesh, ColwiseParallel())
         parallelize_module(self.wv, device_mesh, ColwiseParallel())
-        parallelize_module(
-            self.wo, device_mesh, RowwiseParallel(output_layouts=Shard(1))
-        )
+        parallelize_module(self.wo, device_mesh, RowwiseParallel())
 
     def forward(
         self,
@@ -803,10 +785,6 @@ def forward(
         mask: Tensor,
         input_pos: Optional[Tensor] = None,
     ) -> Tensor:
-        # Gather sequence back in case of sequence parallelism before attention
-        if isinstance(x, DTensor):
-            x = x.redistribute(self.device_mesh, [Replicate()])
-
         bsz, seqlen, _ = x.shape
 
         q = self.wq(x)
@@ -852,18 +830,11 @@ def __init__(self, config: TransformerArgs) -> None:
         self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
 
     def distribute(self, device_mesh: DeviceMesh):
-        self.device_mesh = device_mesh
         parallelize_module(self.w1, device_mesh, ColwiseParallel())
-        parallelize_module(
-            self.w2, device_mesh, RowwiseParallel(output_layouts=Shard(1))
-        )
+        parallelize_module(self.w2, device_mesh, RowwiseParallel())
         parallelize_module(self.w3, device_mesh, ColwiseParallel())
 
     def forward(self, x: Tensor) -> Tensor:
-        # Gather sequence back in case of sequence parallelism
-        if isinstance(x, DTensor):
-            x = x.redistribute(self.device_mesh, [Replicate()])
-
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
 
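
For context, below is a minimal, standalone sketch of the plain tensor-parallel layout this diff converges on: column-wise shards on the up-projections and a row-wise shard on the down-projection, whose default Replicate() output means the caller never sees a sequence-sharded tensor and no DTensor.redistribute() is needed in forward(). The ToyFeedForward module, the distribute_ffn helper, the dimensions, and the launch details are illustrative assumptions, not part of the source above.

import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyFeedForward(nn.Module):
    """Stand-in for the SwiGLU feed-forward block in the diff (illustrative only)."""

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


def distribute_ffn(ffn: ToyFeedForward, mesh) -> None:
    # Column-wise shards on the up-projections, row-wise on the down-projection.
    # RowwiseParallel defaults to output_layouts=Replicate(), so w2's all-reduce
    # hands every rank a plain, fully replicated activation -- no Shard(1)
    # sequence layout anywhere in the forward path.
    parallelize_module(ffn.w1, mesh, ColwiseParallel())
    parallelize_module(ffn.w3, mesh, ColwiseParallel())
    parallelize_module(ffn.w2, mesh, RowwiseParallel())


if __name__ == "__main__":
    # Assumes a torchrun launch on a single node with one GPU per rank.
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    dist.init_process_group("nccl")
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))

    ffn = ToyFeedForward(dim=256, hidden_dim=1024).cuda()
    distribute_ffn(ffn, mesh)

    x = torch.randn(2, 16, 256, device="cuda")  # (batch, seq, dim), replicated input
    print(ffn(x).shape)                         # torch.Size([2, 16, 256]) on every rank

    dist.destroy_process_group()

Run with, e.g., torchrun --nproc-per-node=2; every rank prints the same replicated shape. The diff above wires this same Colwise/Rowwise pairing through Attention.distribute (wq/wk/wv column-wise, wo row-wise) and FeedForward.distribute (w1/w3 column-wise, w2 row-wise).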