@@ -116,6 +116,26 @@ def broadcast(self, obj, root=0):
     def allgather(self, obj, root=0):
         pass
 
+    @abstractmethod
+    def tp_broadcast(self, obj, root=0, **kwargs):
+        pass
+
+    @abstractmethod
+    def cp_broadcast(self, obj, root=0, **kwargs):
+        pass
+
+    def tp_cp_broadcast(self, obj, root=0, **kwargs):
+        """Broadcast object across both TP and CP groups.
+
+        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
+        First broadcasts within the TP group, then within the CP group.
+        """
+        if self.tp_size > 1:
+            obj = self.tp_broadcast(obj, root=root, **kwargs)
+        if self.cp_size > 1:
+            obj = self.cp_broadcast(obj, root=root, **kwargs)
+        return obj
+
 
 def safe_broadcast(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
     """
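For reference, a minimal usage sketch of the new combined broadcast. The `dist` handle and payload below are illustrative assumptions; only `tp_cp_broadcast` and its TP-then-CP ordering come from this diff:

    # `dist` stands for any concrete subclass that implements
    # tp_broadcast/cp_broadcast (handle name assumed for illustration).
    payload = {"top_k": 8, "temperature": 0.7}  # meaningful only on root

    # tp_cp_broadcast fans the object out across the TP group first; each
    # TP rank then re-broadcasts within its CP group, so every (tp, cp)
    # rank ends up holding the same object.
    payload = dist.tp_cp_broadcast(payload, root=0)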
@@ -407,14 +427,26 @@ def create_cp_comm(self):
     def cp_allgather(self, obj):
         return self.cp_comm.allgather(obj)
 
+    def cp_broadcast(self,
+                     obj,
+                     root=0,
+                     chunk_size: int = 4 * 1024 * 1024,
+                     **kwargs):
+        comm = self.cp_comm
+        return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
+
     def tp_allgather(self, obj):
         return self.tp_comm.allgather(obj)
 
     def tp_gather(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
         comm = self.tp_comm
         return safe_gather(comm, obj, root=root, chunk_size=chunk_size)
 
-    def tp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+    def tp_broadcast(self,
+                     obj,
+                     root=0,
+                     chunk_size: int = 4 * 1024 * 1024,
+                     **kwargs):
         comm = self.tp_comm
         return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
 
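Both delegates reuse `safe_broadcast` with a 4 MiB chunk size. Its body lies outside these hunks; the sketch below shows the usual pattern such a helper follows, assuming an mpi4py-style communicator (everything here beyond the signature is an assumption, not the repository's implementation):

    import pickle

    def safe_broadcast_sketch(comm, obj, root=0, chunk_size=4 * 1024 * 1024):
        # Root pickles once; all ranks learn the total size, then the
        # payload is broadcast in chunk_size slices to cap message sizes.
        rank = comm.Get_rank()
        payload = pickle.dumps(obj) if rank == root else None
        size = comm.bcast(len(payload) if rank == root else None, root=root)
        chunks = []
        for off in range(0, size, chunk_size):
            part = payload[off:off + chunk_size] if rank == root else None
            chunks.append(comm.bcast(part, root=root))
        return obj if rank == root else pickle.loads(b"".join(chunks))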
@@ -699,7 +731,7 @@ def tp_gather(self, obj, dst=0):
         return output_list
 
     @log_op
-    def tp_broadcast(self, obj, root=0):
+    def tp_broadcast(self, obj, root=0, **kwargs):
         if isinstance(obj, torch.Tensor):
             dist.broadcast(obj, src=root, group=self.mapping.tp_group_pg)
             return obj
@@ -712,6 +744,20 @@ def tp_broadcast(self, obj, root=0):
                 device=torch.device("cpu"))
             return ret[0]
 
+    @log_op
+    def cp_broadcast(self, obj, root=0, **kwargs):
+        if isinstance(obj, torch.Tensor):
+            dist.broadcast(obj, src=root, group=self.mapping.cp_group_pg)
+            return obj
+        else:
+            ret = [obj]
+            torch.distributed.broadcast_object_list(
+                ret,
+                src=root,
+                group=self.mapping.cp_group_pg,
+                device=torch.device("cpu"))
+            return ret[0]
+
     @log_op
     def pp_allgather(self, obj):
         if isinstance(obj, torch.Tensor):
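The torch.distributed backend splits on payload type: tensors go through `dist.broadcast`, which operates in place on the tensor's current device, while arbitrary Python objects are pickled and staged through CPU memory via `broadcast_object_list`. A standalone sketch of that split, assuming an already-initialized process group (`group=None` means the default group; the function name is hypothetical):

    import torch
    import torch.distributed as dist

    def broadcast_any(obj, root=0, group=None):
        # Mirrors the tp_broadcast/cp_broadcast pattern above, outside any class.
        if isinstance(obj, torch.Tensor):
            dist.broadcast(obj, src=root, group=group)  # in place, on obj's device
            return obj
        ret = [obj]  # non-root ranks may pass a placeholder here
        dist.broadcast_object_list(ret, src=root, group=group,
                                   device=torch.device("cpu"))
        return ret[0]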