
Commit 763e572

VieEeEw authored and facebook-github-bot committed
Updated broadcast_str to use the correct tensor size
Summary: Before the fix, the tensor size on the source rank was [] while on the other ranks it was [1]. Broadcasting from [] to [1] is illegal usage; the bug just happened not to cause any failures.

Reviewed By: diego-urgell, JKSenthil, fduwjj

Differential Revision: D77901586

fbshipit-source-id: 2be33d7a2fcd1113995fcff499b6e90b06c0abb3
1 parent 58be584 · commit 763e572
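
The root cause is that torch.tensor((n)) and torch.tensor([n]) do not produce the same shape. A minimal sketch of the difference (illustration only, not code from this commit):

    import torch

    n = 3
    scalar = torch.tensor((n))   # (n) is just n, so this is a 0-dim tensor: shape torch.Size([])
    vector = torch.tensor([n])   # a 1-element tensor: shape torch.Size([1])

    assert scalar.shape == torch.Size([])
    assert vector.shape == torch.Size([1])

    # torch.distributed.broadcast expects the tensor on every rank to have the
    # same shape, so sending a [] tensor from the source rank while the other
    # ranks hold [1] tensors is the illegal usage described above.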

File tree: 2 files changed, +7 −2 lines

tests/utils/test_distributed_gpu.py

Lines changed: 6 additions & 1 deletion
@@ -86,12 +86,17 @@ def test_spawn_multi_process(self) -> None:
     def test_broadcast_str(self) -> None:
         spawn_multi_process(2, "gloo", self._test_broadcast_str)

+    @skip_if_not_gpu
+    @skip_if_not_distributed
+    def test_broadcast_str_gpu(self) -> None:
+        spawn_multi_process(2, "nccl", self._test_broadcast_str)
+
     @staticmethod
     def _test_broadcast_str() -> None:
         """
         Tests that test_broadcast_str works as expected
         """
-
+        init_from_env()
         val = None
         if dist.get_rank() == 0:
             val = "foo"

torchtnt/utils/distributed.py

Lines changed: 1 addition & 1 deletion
@@ -738,7 +738,7 @@ def broadcast_str(
         # convert string to tensor
         buffer = torch.frombuffer(val.encode("utf-8"), dtype=torch.uint8)
         buffer = buffer.to(device=device)
-        buffer_length = torch.tensor((len(buffer)), dtype=torch.int, device=device)
+        buffer_length = torch.tensor([len(buffer)], dtype=torch.int, device=device)

         if fixed_buffer_size is not None:
             if len(buffer) > fixed_buffer_size:
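
For context, a minimal sketch of why both sides of the broadcast need a shape-[1] length tensor. The broadcast_length helper below is hypothetical and for illustration only; it is not the torchtnt implementation:

    import torch
    import torch.distributed as dist

    def broadcast_length(payload_len, src, device):
        # Hypothetical helper, assuming the process group is already initialized.
        if dist.get_rank() == src:
            # Source rank: shape [1], matching what the fixed line above produces.
            # torch.tensor((payload_len)) would instead be a 0-dim tensor.
            length = torch.tensor([payload_len], dtype=torch.int, device=device)
        else:
            # Non-source ranks allocate a placeholder with the same shape [1].
            length = torch.zeros(1, dtype=torch.int, device=device)
        dist.broadcast(length, src=src)
        return int(length.item())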
