@@ -684,6 +684,69 @@ def wrapper(*args: TParams.args, **kwargs: TParams.kwargs) -> TReturn:
    return wrapper


+def broadcast_str(
+    val: Optional[str],
+    src: int = 0,
+    process_group: Optional[dist.ProcessGroup] = None,
+) -> Optional[str]:
+    """
+    Broadcasts a string from a source rank to all other ranks in a process group.
+    Serializes the string as a sequence of uint8 values and broadcasts it as a tensor. This
+    avoids issues with broadcast_object_list and related APIs, which use pickle to serialize objects.
+
+    Args:
+        val: the string to broadcast
+        src: the source rank to broadcast from
+        process_group: the process group to broadcast in. Defaults to the WORLD process group.
+
+    Returns:
+        The broadcasted string.
+
+    Note:
+        This function issues two collective calls: one to broadcast the size of the serialized
+        string and one to broadcast the string itself. This could in theory be reduced to a single
+        collective by hardcoding a maximum buffer size and filling unused buffer slots with
+        preselected null tokens; however, that is not implemented here to avoid unnecessary complexity.
+    """
+    if not dist.is_available() or not dist.is_initialized():
+        return val
+
+    rank = dist.get_rank(group=process_group)
+
+    # device to use when broadcasting the string
+    device = torch.device(
+        torch.cuda.current_device()
+        if dist.get_backend(process_group) == "nccl"
+        else "cpu"
+    )
+
+    # dummy instantiation to keep pyre happy
+    buffer = torch.empty((1), dtype=torch.uint8)
+    buffer_length = torch.empty((1), dtype=torch.int, device=device)
+    if rank == src:
+        assert (
+            val is not None
+        ), "Source rank must provide a string to broadcast, got None"
+
+        # convert the string to a uint8 tensor on the broadcast device
+        buffer = torch.frombuffer(val.encode("utf-8"), dtype=torch.uint8)
+        buffer = buffer.to(device=device)
+        buffer_length = torch.tensor((len(buffer)), dtype=torch.int, device=device)
+
+    # first broadcast the buffer length so receiving ranks can allocate the correct amount of memory
+    dist.broadcast(buffer_length, src=src, group=process_group)
+    if rank != src:
+        size = int(buffer_length.item())
+        buffer = torch.empty((size), dtype=torch.uint8, device=device)
+
+    # now broadcast the string itself
+    dist.broadcast(buffer, src=src, group=process_group)
+
+    # convert the tensor back to a string
+    string = bytes(buffer.tolist()).decode(encoding="utf-8")
+    return string
+
+
@contextmanager
def get_or_create_gloo_pg(
    candidate_pg: Optional[dist.ProcessGroup] = None,
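
For context, a minimal usage sketch of the new helper. The import path `distributed_utils` and the launch command are illustrative assumptions, not part of this PR; the real module path depends on where `broadcast_str` lands in the package.

    # hypothetical usage sketch; launch with e.g. `torchrun --nproc-per-node=2 demo.py`
    import torch.distributed as dist

    from distributed_utils import broadcast_str  # hypothetical import path

    dist.init_process_group(backend="gloo")

    # only the source rank needs to supply the value; other ranks pass None
    run_id = "run-2024-abc" if dist.get_rank() == 0 else None
    run_id = broadcast_str(run_id, src=0)

    # every rank now holds the same string
    print(f"rank {dist.get_rank()}: {run_id}")

    dist.destroy_process_group()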
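The docstring's Note mentions a single-collective alternative. Purely to illustrate that trade-off (this is not in the PR), here is a sketch under stated assumptions: a CPU/gloo setup, a hardcoded maximum size, and hypothetical names `_MAX_LEN` and `broadcast_str_fixed`.

    import torch
    import torch.distributed as dist
    from typing import Optional

    _MAX_LEN = 4096  # hypothetical hardcoded maximum serialized size

    def broadcast_str_fixed(val: Optional[str], src: int = 0) -> str:
        # zero-filled buffer: the unused slots are the "preselected null tokens"
        buffer = torch.zeros(_MAX_LEN, dtype=torch.uint8)
        if dist.get_rank() == src:
            assert val is not None
            encoded = val.encode("utf-8")
            assert len(encoded) <= _MAX_LEN, "string exceeds fixed buffer size"
            buffer[: len(encoded)] = torch.frombuffer(bytearray(encoded), dtype=torch.uint8)
        # a single collective: the length is implied by the null padding
        dist.broadcast(buffer, src=src)
        # strip trailing nulls; assumes the payload itself contains no "\x00"
        return bytes(buffer.tolist()).decode("utf-8").rstrip("\x00")

The cost is a wasted `_MAX_LEN`-byte broadcast for short strings and a hard failure for long ones, which is presumably why the PR prefers the two-call protocol.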