
Commit f5bf0c1

JKSenthil authored and facebook-github-bot committed
add all gather string util (#1004)
Summary:
Pull Request resolved: #1004
Reviewed By: galrotem
Differential Revision: D75021212
fbshipit-source-id: fcb01621f6a4d8d8d7ef4166363bf4046bbbefc8
1 parent cb31137 commit f5bf0c1

File tree

2 files changed: +67 -0 lines changed


tests/utils/test_distributed.py

Lines changed: 27 additions & 0 deletions
@@ -20,6 +20,7 @@
 from torch.distributed import ProcessGroup
 from torchtnt.utils.distributed import (
     _validate_global_rank_world_size,
+    all_gather_str,
     all_gather_tensors,
     broadcast_str,
     destroy_process_group,
@@ -610,3 +611,29 @@ def _test_broadcast_str() -> None:

         tc = unittest.TestCase()
         tc.assertEqual(broadcasted_val, "foo")
+
+    @skip_if_not_distributed
+    def test_all_gather_str(self) -> None:
+        backend = "gloo"
+        if torch.cuda.is_available():
+            backend = "nccl"
+
+        spawn_multi_process(2, backend, self._test_all_gather_str)
+
+    @staticmethod
+    def _test_all_gather_str() -> None:
+        if torch.cuda.is_available():
+            torch.cuda.set_device(dist.get_rank())
+
+        val = None
+        if dist.get_rank() == 0:
+            val = "foo"
+        else:
+            val = "barzoo"
+
+        # Test case 1: fixed_buffer_size == len(val)
+        vals = all_gather_str(val)
+
+        tc = unittest.TestCase()
+        tc.assertEqual(vals[0], "foo")
+        tc.assertEqual(vals[1], "barzoo")
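
For orientation, here is a minimal standalone sketch of how the new utility could be driven outside the test harness, assuming a two-rank job launched with `torchrun --nproc_per_node=2` and the gloo backend (the script shape and launch command are illustrative, not part of this commit):

```python
# illustrative demo of all_gather_str across two ranks (not part of the commit)
import torch.distributed as dist

from torchtnt.utils.distributed import all_gather_str


def main() -> None:
    # assumes the env:// rendezvous variables set by torchrun
    dist.init_process_group("gloo")

    # each rank contributes a string; lengths may differ across ranks
    vals = all_gather_str(f"rank-{dist.get_rank()}")

    # every rank receives the same rank-ordered list, e.g. ["rank-0", "rank-1"]
    print(vals)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```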

torchtnt/utils/distributed.py

Lines changed: 40 additions & 0 deletions
@@ -772,6 +772,46 @@ def broadcast_str(
     return string


+def all_gather_str(
+    val: str, process_group: Optional[dist.ProcessGroup] = None
+) -> List[str]:
+    """
+    Optimized all-gather of a string that avoids invoking all_gather_object,
+    which is subject to hang issues on nccl.
+
+    Args:
+        val: string to include in the all_gather
+        process_group: the process group to gather in
+
+    Returns:
+        List of the strings from all ranks
+
+    Note:
+        Will construct and use a temporary gloo process group to minimize
+        device-to-host transfers.
+
+    TODO: support fixed_buffer_size
+    """
+
+    if not dist.is_available() or not dist.is_initialized():
+        return [val]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # use gloo so that we avoid gpu->cpu (device to host) transfers
+    # with get_or_create_gloo_pg(process_group) as gloo_pg:
+
+    # encode the string as a uint8 buffer on the local device
+    buffer = torch.frombuffer(val.encode("utf-8"), dtype=torch.uint8).to(device)
+    # use `all_gather_tensors`, which handles gathering tensors of the
+    # same shape but different lengths (strings may differ in length
+    # across ranks)
+    buffer_strings = all_gather_tensors(buffer, group=process_group)
+
+    result = [bytes(buffer.tolist()).decode("utf-8") for buffer in buffer_strings]
+
+    return result
+
+
 @contextmanager
 def get_or_create_gloo_pg(
     candidate_pg: Optional[dist.ProcessGroup] = None,
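
The added code leans on `all_gather_tensors` to reconcile per-rank length differences. As a rough mental model only (not the library's actual implementation), the usual pad-to-max strategy for gathering variable-length tensors looks like the sketch below; `_gather_variable_length` is a hypothetical name:

```python
# hypothetical sketch of a variable-length all-gather (pad-to-max strategy)
from typing import List, Optional

import torch
import torch.distributed as dist


def _gather_variable_length(
    buffer: torch.Tensor, group: Optional[dist.ProcessGroup] = None
) -> List[torch.Tensor]:
    world_size = dist.get_world_size(group=group)

    # step 1: exchange lengths so every rank knows how much to trim later
    local_len = torch.tensor([buffer.numel()], dtype=torch.int64, device=buffer.device)
    lengths = [torch.zeros_like(local_len) for _ in range(world_size)]
    dist.all_gather(lengths, local_len, group=group)

    # step 2: pad to the max length so all_gather sees identically shaped tensors
    max_len = int(max(l.item() for l in lengths))
    padded = torch.zeros(max_len, dtype=buffer.dtype, device=buffer.device)
    padded[: buffer.numel()] = buffer

    gathered = [torch.zeros_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded, group=group)

    # step 3: trim each gathered tensor back to its true length
    return [t[: int(l.item())] for t, l in zip(gathered, lengths)]
```

Each trimmed uint8 buffer then decodes back to a string via `bytes(t.tolist()).decode("utf-8")`, matching the final step in the diff above.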
