diff --git a/tests/utils/test_distributed_gpu.py b/tests/utils/test_distributed_gpu.py
index abae65a80a..942ca2b6c5 100644
--- a/tests/utils/test_distributed_gpu.py
+++ b/tests/utils/test_distributed_gpu.py
@@ -86,12 +86,17 @@ def test_spawn_multi_process(self) -> None:
     def test_broadcast_str(self) -> None:
         spawn_multi_process(2, "gloo", self._test_broadcast_str)
 
+    @skip_if_not_gpu
+    @skip_if_not_distributed
+    def test_broadcast_str_gpu(self) -> None:
+        spawn_multi_process(2, "nccl", self._test_broadcast_str)
+
     @staticmethod
     def _test_broadcast_str() -> None:
         """
         Tests that test_broadcast_str works as expected
         """
-
+        init_from_env()
         val = None
         if dist.get_rank() == 0:
             val = "foo"
diff --git a/torchtnt/utils/distributed.py b/torchtnt/utils/distributed.py
index f5524e7d23..385b75da10 100644
--- a/torchtnt/utils/distributed.py
+++ b/torchtnt/utils/distributed.py
@@ -738,7 +738,7 @@ def broadcast_str(
     # convert string to tensor
     buffer = torch.frombuffer(val.encode("utf-8"), dtype=torch.uint8)
     buffer = buffer.to(device=device)
-    buffer_length = torch.tensor((len(buffer)), dtype=torch.int, device=device)
+    buffer_length = torch.tensor([len(buffer)], dtype=torch.int, device=device)
 
     if fixed_buffer_size is not None:
         if len(buffer) > fixed_buffer_size:
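
Note on the change in torchtnt/utils/distributed.py: in Python, (len(buffer)) is just len(buffer) wrapped in redundant parentheses, so the old call built a 0-dim scalar tensor, while [len(buffer)] builds a 1-D tensor of shape (1,). A minimal standalone sketch (not part of the PR; the values and variable names are illustrative) showing the shape difference:

    import torch

    # Redundant parentheses: the argument is a plain int, so torch.tensor
    # returns a 0-dim scalar tensor with shape torch.Size([]).
    scalar = torch.tensor((5), dtype=torch.int)

    # A one-element list yields a 1-D tensor with shape torch.Size([1]),
    # matching the form the fixed line now builds for the length buffer.
    vector = torch.tensor([5], dtype=torch.int)

    assert scalar.dim() == 0 and scalar.shape == torch.Size([])
    assert vector.dim() == 1 and vector.shape == torch.Size([1])

The 1-D form can be indexed (e.g. buffer_length[0]), whereas indexing a 0-dim tensor with [0] raises an IndexError, which is presumably why the fix switches to the list form.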