File tree Expand file tree Collapse file tree 2 files changed +67
-0
lines changed Expand file tree Collapse file tree 2 files changed +67
-0
lines changed Original file line number Diff line number Diff line change 20
20
from torch .distributed import ProcessGroup
21
21
from torchtnt .utils .distributed import (
22
22
_validate_global_rank_world_size ,
23
+ all_gather_str ,
23
24
all_gather_tensors ,
24
25
broadcast_str ,
25
26
destroy_process_group ,
@@ -610,3 +611,29 @@ def _test_broadcast_str() -> None:
610
611
611
612
tc = unittest .TestCase ()
612
613
tc .assertEqual (broadcasted_val , "foo" )
614
+
615
+ @skip_if_not_distributed
616
+ def test_all_gather_str (self ) -> None :
617
+ backend = "gloo"
618
+ if torch .cuda .is_available ():
619
+ backend = "nccl"
620
+
621
+ spawn_multi_process (2 , backend , self ._test_all_gather_str )
622
+
623
+ @staticmethod
624
+ def _test_all_gather_str () -> None :
625
+ if torch .cuda .is_available ():
626
+ torch .cuda .set_device (dist .get_rank ())
627
+
628
+ val = None
629
+ if dist .get_rank () == 0 :
630
+ val = "foo"
631
+ else :
632
+ val = "barzoo"
633
+
634
+ # Test case 1: fixed_buffer_size == len(val)
635
+ vals = all_gather_str (val )
636
+
637
+ tc = unittest .TestCase ()
638
+ tc .assertEqual (vals [0 ], "foo" )
639
+ tc .assertEqual (vals [1 ], "barzoo" )
Original file line number Diff line number Diff line change @@ -772,6 +772,46 @@ def broadcast_str(
772
772
return string
773
773
774
774
775
def all_gather_str(
    val: str, process_group: Optional[dist.ProcessGroup] = None
) -> List[str]:
    """
    Optimized all-gather of a string across ranks without invoking
    ``all_gather_object``, which is subject to hang issues on nccl.

    Args:
        val: the string this rank contributes to the all_gather
        process_group: the process group to gather in. If ``None``, the
            default process group is used.

    Returns:
        List of all gathered strings, indexed by rank. If the distributed
        package is unavailable or uninitialized, returns ``[val]``.

    Note:
        The string is encoded as a uint8 tensor and moved to the current
        cuda device when one is available (as required by nccl backends),
        then gathered via ``all_gather_tensors``, which supports tensors of
        differing lengths across ranks.

    TODO: support fixed_buffer_size
    """
    # Single-process fallback: behave like a one-rank gather.
    if not dist.is_available() or not dist.is_initialized():
        return [val]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Encode the string into a uint8 tensor. torch.tensor(...) copies the
    # data (unlike torch.frombuffer, which warns on non-writable `bytes`
    # buffers and does not accept an empty buffer), so the empty string is
    # handled as a zero-length tensor.
    buffer = torch.tensor(
        list(val.encode("utf-8")), dtype=torch.uint8, device=device
    )

    # `all_gather_tensors` handles gathering same-rank tensors of different
    # lengths, since the strings may differ in length on each rank.
    gathered = all_gather_tensors(buffer, group=process_group)

    return [bytes(t.tolist()).decode("utf-8") for t in gathered]
813
+
814
+
775
815
@contextmanager
776
816
def get_or_create_gloo_pg (
777
817
candidate_pg : Optional [dist .ProcessGroup ] = None ,
You can’t perform that action at this time.
0 commit comments