
Commit d050dcd

JKSenthil authored and facebook-github-bot committed
add broadcast_str util (#986)
Summary: Pull Request resolved: #986

Reviewed By: diego-urgell

Differential Revision: D71653630

fbshipit-source-id: 35e1691b31482c0a395c5032da08f5cf869c7cfa
1 parent 3dfcb7d commit d050dcd

File tree

3 files changed: +104 additions, 0 deletions


tests/utils/test_distributed.py

Lines changed: 20 additions & 0 deletions
@@ -21,6 +21,7 @@
 from torchtnt.utils.distributed import (
     _validate_global_rank_world_size,
     all_gather_tensors,
+    broadcast_str,
     destroy_process_group,
     get_file_init_method,
     get_global_rank,
@@ -571,3 +572,22 @@ def get_backend(_) -> str:
             raise Exception("Test Exception")

         mock_destroy_process_group.assert_called_once_with(pg)
+
+    @skip_if_not_distributed
+    def test_broadcast_str(self) -> None:
+        spawn_multi_process(2, "gloo", self._test_broadcast_str)
+
+    @staticmethod
+    def _test_broadcast_str() -> None:
+        """
+        Tests that broadcast_str works as expected
+        """
+
+        val = None
+        if dist.get_rank() == 0:
+            val = "foo"
+
+        broadcasted_val = broadcast_str(val)
+
+        tc = unittest.TestCase()
+        tc.assertEqual(broadcasted_val, "foo")

tests/utils/test_distributed_gpu.py

Lines changed: 21 additions & 0 deletions
@@ -14,6 +14,7 @@
 from torchtnt.utils.device import get_device_from_env
 from torchtnt.utils.distributed import (
     all_gather_tensors,
+    broadcast_str,
     get_global_rank,
     get_local_rank,
     PGWrapper,
@@ -79,3 +80,23 @@ def _test_method(offset_arg: int, offset_kwarg: int) -> int:
     def test_spawn_multi_process(self) -> None:
         mp_list = spawn_multi_process(2, "nccl", self._test_method, 3, offset_kwarg=2)
         self.assertEqual(mp_list, [1, 2])
+
+    @skip_if_not_gpu
+    @skip_if_not_distributed
+    def test_broadcast_str(self) -> None:
+        spawn_multi_process(2, "gloo", self._test_broadcast_str)
+
+    @staticmethod
+    def _test_broadcast_str() -> None:
+        """
+        Tests that broadcast_str works as expected
+        """
+
+        val = None
+        if dist.get_rank() == 0:
+            val = "foo"
+
+        broadcasted_val = broadcast_str(val)
+
+        tc = unittest.TestCase()
+        tc.assertEqual(broadcasted_val, "foo")

torchtnt/utils/distributed.py

Lines changed: 63 additions & 0 deletions
@@ -684,6 +684,69 @@ def wrapper(*args: TParams.args, **kwargs: TParams.kwargs) -> TReturn:
     return wrapper


+def broadcast_str(
+    val: Optional[str],
+    src: int = 0,
+    process_group: Optional[dist.ProcessGroup] = None,
+) -> Optional[str]:
+    """
+    Broadcasts a string from a source rank to all other ranks in a process group.
+    Serializes the string as a sequence of uint8 and broadcasts it as a tensor. This avoids
+    issues with broadcast_object_list and related APIs, which use pickle to serialize objects.
+
+    Args:
+        val: the string to broadcast
+        src: the source rank to broadcast from
+        process_group: the process group to broadcast in. Defaults to the WORLD process group.
+
+    Returns:
+        The broadcasted string.
+
+    Note:
+        This function issues two collective calls: one to broadcast the size of the serialized string and
+        one to broadcast the string itself. This could theoretically be limited to one collective call
+        by hardcoding a maximum buffer size and filling unused buffer slots with preselected
+        null tokens. However, this is not implemented, to avoid unnecessary complexity.
+    """
+    if not dist.is_available() or not dist.is_initialized():
+        return val
+
+    rank = dist.get_rank(group=process_group)
+
+    # device to use when broadcasting the string
+    device = torch.device(
+        torch.cuda.current_device()
+        if dist.get_backend(process_group) == "nccl"
+        else "cpu"
+    )
+
+    # dummy instantiation to keep pyre happy
+    buffer = torch.empty((1), dtype=torch.uint8)
+    buffer_length = torch.empty((1), dtype=torch.int, device=device)
+    if rank == src:
+        assert (
+            val is not None
+        ), "Source rank must provide a string to broadcast, got None"
+
+        # convert string to tensor
+        buffer = torch.frombuffer(val.encode("utf-8"), dtype=torch.uint8)
+        buffer = buffer.to(device=device)
+        buffer_length = torch.tensor((len(buffer)), dtype=torch.int, device=device)
+
+    # first broadcast the buffer length so receiving ranks can allocate the correct amount of memory
+    dist.broadcast(buffer_length, src=src, group=process_group)
+    if rank != src:
+        size = int(buffer_length.item())
+        buffer = torch.empty((size), dtype=torch.uint8, device=device)
+
+    # now broadcast the string itself
+    dist.broadcast(buffer, src=src, group=process_group)
+
+    # convert the tensor back to a string
+    string = bytes(buffer.tolist()).decode(encoding="utf-8")
+    return string
+
+
 @contextmanager
 def get_or_create_gloo_pg(
     candidate_pg: Optional[dist.ProcessGroup] = None,
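For reference, a minimal usage sketch of the new utility. This is not part of the commit: the script name, the gloo backend choice, and the run-ID payload are illustrative assumptions.

# broadcast_run_id.py -- illustrative only; launch with e.g.
#   torchrun --nproc_per_node=2 broadcast_run_id.py
import uuid

import torch.distributed as dist

from torchtnt.utils.distributed import broadcast_str


def main() -> None:
    # torchrun sets the env vars that init_process_group needs
    dist.init_process_group("gloo")

    # only the source rank (0 by default) needs to provide a value
    run_id = str(uuid.uuid4()) if dist.get_rank() == 0 else None

    # after the broadcast, every rank holds rank 0's string
    run_id = broadcast_str(run_id)
    print(f"rank {dist.get_rank()} got run_id={run_id}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

And a rough sketch of the single-collective alternative the docstring's Note describes (pad to a fixed maximum size, broadcast once, strip the null padding on receipt). MAX_LEN and broadcast_str_fixed are hypothetical names; this variant is deliberately not implemented in the commit.

from typing import Optional

import torch
import torch.distributed as dist

# hypothetical cap on the encoded string length
MAX_LEN = 4096


def broadcast_str_fixed(val: Optional[str], src: int = 0) -> str:
    # null-filled buffer of fixed size, so no separate length broadcast is needed
    buffer = torch.zeros(MAX_LEN, dtype=torch.uint8)
    if dist.get_rank() == src:
        assert val is not None
        encoded = val.encode("utf-8")
        assert len(encoded) <= MAX_LEN, "string too long for the fixed buffer"
        buffer[: len(encoded)] = torch.frombuffer(bytearray(encoded), dtype=torch.uint8)

    # a single collective, at the cost of always sending MAX_LEN bytes
    dist.broadcast(buffer, src=src)

    # strip the trailing null padding before decoding
    return bytes(buffer.tolist()).rstrip(b"\x00").decode("utf-8")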

0 commit comments
