
Commit 1348545

fix(xla): convert group-local to global ranks in broadcast (#9657)
Related AWS Neuron ticket: https://t.corp.amazon.com/V1941917988/overview

`broadcast` was passing group-local ranks directly to `xm.collective_broadcast()`, which expects global ranks, causing data corruption in single-member process groups.

TEST:

```
import os

import torch
import torch.distributed as dist
import torch_xla as xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr


def main():
  dist.init_process_group(backend="xla")
  rank = dist.get_rank()
  world_size = dist.get_world_size()

  tp = dist.new_group(ranks=[rank])
  tp_rank = dist.get_rank(group=tp)
  tp_size = dist.get_world_size(group=tp)

  print(
      f">>>> pid={os.getpid()}, rank={rank}\n"
      f">>> world_size={world_size}\n"
      f">>> tp_rank={tp_rank}, tp_size={tp_size}, tp_members={dist.get_process_group_ranks(tp)}"
  )

  do_train, do_valid, do_test = 0.1, 0.2, 0.3
  flags = torch.tensor([do_train, do_valid, do_test],
                       dtype=torch.float32,
                       device='xla')
  dist.broadcast(flags, rank, group=tp)

  print(f">>>> pid={os.getpid()}, rank={rank}\n"
        f">>> do_train={flags[0].item()}, do_valid={flags[1].item()}, do_test={flags[2].item()}\n"
        f">>> global_ordinal={xr.global_ordinal()}")


if __name__ == "__main__":
  main()
```

Results after this fix:

```
torchrun --nproc-per-node=2 --nnodes=1 ./bug.py
W0926 18:50:41.903000 1081605 torch/distributed/run.py:766]
W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] *****************************************
W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] *****************************************
>>>> pid=1081679, rank=0
>>> world_size=2
>>> tp_rank=0, tp_size=1, tp_members=[0]
>>>> pid=1081680, rank=1
>>> world_size=2
>>> tp_rank=0, tp_size=1, tp_members=[1]
.
.
.
2.19.8089.0+8ab9f450/MODULE_10344927339446294134+e30acd3a/model.neff
>>>> pid=1081680, rank=1
>>> do_train=0.10000000149011612, do_valid=0.20000000298023224, do_test=0.30000001192092896
>>> global_ordinal=1
>>>> pid=1081679, rank=0
>>> do_train=0.10000000149011612, do_valid=0.20000000298023224, do_test=0.30000001192092896
```

Now both ranks have the correct values. Previously, rank 1 received all zeros.
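For context, a minimal sketch of the rank mismatch behind the corruption, assuming a two-process job like the reproduction above and viewed from global rank 1 (an illustration, not part of the fix):

```
import torch.distributed as dist

# On global rank 1, create a single-member group containing only this process.
tp = dist.new_group(ranks=[1])

group_rank = dist.get_rank(group=tp)                # 0: group-local rank
global_rank = dist.get_global_rank(tp, group_rank)  # 1: what XLA collectives need

# Before this fix, ProcessGroupXla.broadcast passed the group-local rank (0)
# straight to xm.collective_broadcast(), so the process with global ordinal 1
# was never treated as the root of its own broadcast.
```

The change below applies exactly this `dist.get_global_rank()` conversion inside `ProcessGroupXla.broadcast`.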
1 parent 420adaa commit 1348545


2 files changed: +93, -1 lines changed


test/test_torch_distributed_xla_backend.py

Lines changed: 87 additions & 0 deletions
@@ -44,6 +44,18 @@ def patch_world(rank, size):
     yield


+@contextlib.contextmanager
+def patch_world_with_xla_runtime(rank, size):
+  assert isinstance(dist.group.WORLD,
+                    torch_xla.distributed.xla_backend.ProcessGroupXla)
+
+  with mock.patch.object(dist.group.WORLD, 'rank', return_value=rank), \
+       mock.patch.object(dist.group.WORLD, 'size', return_value=size), \
+       mock.patch.object(xr, 'global_ordinal', return_value=rank), \
+       mock.patch.object(xr, 'world_size', return_value=size):
+    yield
+
+
 class XlaBackendTest(parameterized.TestCase):

   @classmethod

@@ -328,6 +340,81 @@ def test_unimplemented_op(self, op):
     with self.assertRaises(NotImplementedError):
       getattr(pg_xla, op)(tensor)

+  @patch_world_with_xla_runtime(rank=0, size=2)
+  def test_broadcast_single_rank_group_rank0(self):
+    """Test broadcast in single-member process group for rank 0"""
+    device = torch_xla.device()
+
+    with new_group_barrier_disabled():
+      tp = dist.new_group(ranks=[0])
+
+    # Create flags tensor with initial values (simulating rank 0's values)
+    flags = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32, device=device)
+
+    # Broadcast within the single-member group (should be a no-op but shouldn't crash)
+    dist.broadcast(flags, src=0, group=tp)
+
+    # Values should remain unchanged since it's a single-member group
+    self.assertAlmostEqual(flags[0].item(), 0.1, places=5)
+    self.assertAlmostEqual(flags[1].item(), 0.2, places=5)
+    self.assertAlmostEqual(flags[2].item(), 0.3, places=5)
+
+    # Verify the process group properties
+    self.assertEqual(dist.get_rank(group=tp), 0)
+    self.assertEqual(dist.get_world_size(group=tp), 1)
+
+  @patch_world_with_xla_runtime(rank=1, size=2)
+  def test_broadcast_single_rank_group_rank1(self):
+    """Test broadcast in single-member process group for rank 1"""
+    device = torch_xla.device()
+
+    with new_group_barrier_disabled():
+      tp = dist.new_group(ranks=[1])
+
+    # Create flags tensor with initial values (simulating rank 1's values)
+    flags = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32, device=device)
+
+    # Broadcast within the single-member group (should be a no-op but shouldn't crash)
+    dist.broadcast(flags, src=1, group=tp)
+
+    # Values should remain unchanged since it's a single-member group
+    self.assertAlmostEqual(flags[0].item(), 0.1, places=5)
+    self.assertAlmostEqual(flags[1].item(), 0.2, places=5)
+    self.assertAlmostEqual(flags[2].item(), 0.3, places=5)
+
+    # Verify the process group properties
+    self.assertEqual(dist.get_rank(group=tp),
+                     0)  # Local rank in single-member group is 0
+    self.assertEqual(dist.get_world_size(group=tp), 1)
+
+  @patch_world_with_xla_runtime(rank=0, size=2)
+  def test_broadcast_global_rank_conversion_single_member(self):
+    """Test that global rank conversion works correctly for single-member groups"""
+    device = torch_xla.device()
+
+    # Create single-member group for rank 0
+    with new_group_barrier_disabled():
+      tp = dist.new_group(ranks=[0])
+
+    flags = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32, device=device)
+
+    # Get the ProcessGroupXla instance to test directly
+    self.assertIsInstance(tp, torch_xla.distributed.xla_backend.ProcessGroupXla)
+
+    # Test broadcast options - local rank 0 should map to global rank 0
+    opts = dist.BroadcastOptions()
+    opts.rootRank = 0
+    opts.rootTensor = 0
+
+    # This should work without variable name errors
+    work = tp.broadcast([flags], opts)
+    self.assertIsNotNone(work)
+
+    # Values should be preserved
+    self.assertAlmostEqual(flags[0].item(), 0.1, places=5)
+    self.assertAlmostEqual(flags[1].item(), 0.2, places=5)
+    self.assertAlmostEqual(flags[2].item(), 0.3, places=5)
+

 if __name__ == '__main__':
   if xr.device_type() != 'CPU':

torch_xla/distributed/xla_backend.py

Lines changed: 6 additions & 1 deletion
@@ -131,9 +131,14 @@ def allgather_coalesced(self, output_tensors_list, input_tensors, opts=None):
   # Call site:
   # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L1129
   def broadcast(self, tensors, opts):
+    import torch.distributed as dist
+
     root_tensor = tensors[opts.rootTensor]
+    # Convert group local rank to global rank for xla collectives
+    group_source = opts.rootRank
+    global_src = dist.get_global_rank(self, group_source)
     xm.collective_broadcast([root_tensor],
-                            opts.rootRank,
+                            global_src,
                             groups=self._mesh,
                             pin_layout=False)

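Why the wrong root corrupts data silently instead of raising an error: `xm.collective_broadcast()` decides who contributes data by comparing the root against each process's global ordinal. The sketch below is a simplified illustration of that zero-then-reduce pattern under that assumption, not the library source; `broadcast_like_xla` is a hypothetical helper:

```
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def broadcast_like_xla(tensor, root_global_ordinal, groups):
  # Each replica keeps its tensor only if it believes it is the root...
  scale = 1.0 if xr.global_ordinal() == root_global_ordinal else 0.0
  tensor.mul_(scale)
  # ...then the group sums the tensors, so non-root copies contribute zeros.
  xm.all_reduce(xm.REDUCE_SUM, [tensor], groups=groups, pin_layout=False)
```

With the old group-local root of 0, the process with global ordinal 1 fails that comparison even though it is the only member of its group `[1]`, so the reduction over that group sums nothing but zeros. Converting the root with `dist.get_global_rank()` makes the comparison hold on the intended process.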
