@@ -684,6 +684,69 @@ def wrapper(*args: TParams.args, **kwargs: TParams.kwargs) -> TReturn:
684684 return wrapper
685685
686686
def broadcast_str(
    val: Optional[str],
    src: int = 0,
    process_group: Optional[dist.ProcessGroup] = None,
) -> Optional[str]:
    """
    Broadcasts a string from a source rank to all other ranks in a process group.
    Serializes string as sequence of uint8 and broadcasts as a tensor. This avoids
    issues with broadcast_object_list and related apis which use pickle to serialize objects.

    Args:
        val: the string to broadcast
        src: the source rank to broadcast from
        process_group: the process group to broadcast in. Defaults to the WORLD process group.

    Returns:
        The broadcasted string.

    Note:
        This function issues two collective calls, one to broadcast the size of the serialized string and
        one to broadcast the string itself. This can theoretically be limited to one collective call
        by hardcoding maximum buffer size to use, and filling unused buffer slots with preselected
        null tokens. However, this is not implemented to avoid unnecessary complexity.
    """
    # no-op when torch.distributed is unusable or uninitialized
    if not dist.is_available() or not dist.is_initialized():
        return val

    rank = dist.get_rank(group=process_group)

    # device to use when broadcasting the string; NCCL requires GPU tensors
    device = torch.device(
        torch.cuda.current_device()
        if dist.get_backend(process_group) == "nccl"
        else "cpu"
    )

    # dummy instantiation to keep pyre happy
    buffer = torch.empty((1), dtype=torch.uint8)
    buffer_length = torch.empty((1), dtype=torch.int, device=device)
    if rank == src:
        assert (
            val is not None
        ), "Source rank must provide a string to broadcast, got None"

        # Convert string to tensor. Build from a list of byte values rather than
        # torch.frombuffer: frombuffer raises ValueError on a zero-length buffer
        # (so the empty string would fail) and warns about aliasing the
        # read-only `bytes` object. This also allocates directly on `device`,
        # avoiding an extra host-side tensor + copy.
        encoded = val.encode("utf-8")
        buffer = torch.tensor(list(encoded), dtype=torch.uint8, device=device)
        buffer_length = torch.tensor(len(encoded), dtype=torch.int, device=device)

    # first broadcast the buffer length so receiving ranks can allocate the correct amount of memory
    dist.broadcast(buffer_length, src=src, group=process_group)
    if rank != src:
        size = int(buffer_length.item())
        buffer = torch.empty((size), dtype=torch.uint8, device=device)

    # now broadcast string
    dist.broadcast(buffer, src=src, group=process_group)

    # convert tensor back to string
    string = bytes(buffer.tolist()).decode(encoding="utf-8")
    return string
748+
749+
687750@contextmanager
688751def get_or_create_gloo_pg (
689752 candidate_pg : Optional [dist .ProcessGroup ] = None ,
0 commit comments