@@ -44,6 +44,18 @@ def patch_world(rank, size):
   yield
 
 
+@contextlib.contextmanager
+def patch_world_with_xla_runtime(rank, size):
+  assert isinstance(dist.group.WORLD,
+                    torch_xla.distributed.xla_backend.ProcessGroupXla)
+
+  with mock.patch.object(dist.group.WORLD, 'rank', return_value=rank), \
+       mock.patch.object(dist.group.WORLD, 'size', return_value=size), \
+       mock.patch.object(xr, 'global_ordinal', return_value=rank), \
+       mock.patch.object(xr, 'world_size', return_value=size):
+    yield
+
+
 class XlaBackendTest(parameterized.TestCase):
 
   @classmethod
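
Aside: the new tests below apply patch_world_with_xla_runtime as a decorator
even though it is defined as a context manager. This works because functions
wrapped with contextlib.contextmanager return objects that inherit from
contextlib.ContextDecorator. A minimal sketch of the mechanism, using a
hypothetical fake_world stand-in so it runs without torch_xla:

    import contextlib

    @contextlib.contextmanager
    def fake_world(rank, size):
        # Hypothetical stand-in for the mock.patch-based manager above.
        print(f"patched: rank={rank}, size={size}")
        yield
        print("unpatched")

    @fake_world(rank=0, size=2)  # same enter/exit semantics as a `with` block
    def test_body():
        print("body runs with patches active")

    test_body()
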
@@ -328,6 +340,81 @@ def test_unimplemented_op(self, op):
     with self.assertRaises(NotImplementedError):
       getattr(pg_xla, op)(tensor)
 
+  @patch_world_with_xla_runtime(rank=0, size=2)
+  def test_broadcast_single_rank_group_rank0(self):
+    """Test broadcast in single-member process group for rank 0"""
+    device = torch_xla.device()
+
+    with new_group_barrier_disabled():
+      tp = dist.new_group(ranks=[0])
+
+    # Create flags tensor with initial values (simulating rank 0's values)
+    flags = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32, device=device)
+
+    # Broadcast within the single-member group (should be a no-op but shouldn't crash)
+    dist.broadcast(flags, src=0, group=tp)
+
+    # Values should remain unchanged since it's a single-member group
+    self.assertAlmostEqual(flags[0].item(), 0.1, places=5)
+    self.assertAlmostEqual(flags[1].item(), 0.2, places=5)
+    self.assertAlmostEqual(flags[2].item(), 0.3, places=5)
+
+    # Verify the process group properties
+    self.assertEqual(dist.get_rank(group=tp), 0)
+    self.assertEqual(dist.get_world_size(group=tp), 1)
+
+  @patch_world_with_xla_runtime(rank=1, size=2)
+  def test_broadcast_single_rank_group_rank1(self):
+    """Test broadcast in single-member process group for rank 1"""
+    device = torch_xla.device()
+
+    with new_group_barrier_disabled():
+      tp = dist.new_group(ranks=[1])
+
+    # Create flags tensor with initial values (simulating rank 1's values)
+    flags = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32, device=device)
+
+    # Broadcast within the single-member group (should be a no-op but shouldn't crash)
+    dist.broadcast(flags, src=1, group=tp)
+
+    # Values should remain unchanged since it's a single-member group
+    self.assertAlmostEqual(flags[0].item(), 0.1, places=5)
+    self.assertAlmostEqual(flags[1].item(), 0.2, places=5)
+    self.assertAlmostEqual(flags[2].item(), 0.3, places=5)
+
+    # Verify the process group properties
+    self.assertEqual(dist.get_rank(group=tp),
+                     0)  # Local rank in single-member group is 0
+    self.assertEqual(dist.get_world_size(group=tp), 1)
+
+  @patch_world_with_xla_runtime(rank=0, size=2)
+  def test_broadcast_global_rank_conversion_single_member(self):
+    """Test that global rank conversion works correctly for single-member groups"""
+    device = torch_xla.device()
+
+    # Create single-member group for rank 0
+    with new_group_barrier_disabled():
+      tp = dist.new_group(ranks=[0])
+
+    flags = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32, device=device)
+
+    # Get the ProcessGroupXla instance to test directly
+    self.assertIsInstance(tp, torch_xla.distributed.xla_backend.ProcessGroupXla)
+
+    # Test broadcast options - local rank 0 should map to global rank 0
+    opts = dist.BroadcastOptions()
+    opts.rootRank = 0
+    opts.rootTensor = 0
+
+    # This should work without variable name errors
+    work = tp.broadcast([flags], opts)
+    self.assertIsNotNone(work)
+
+    # Values should be preserved
+    self.assertAlmostEqual(flags[0].item(), 0.1, places=5)
+    self.assertAlmostEqual(flags[1].item(), 0.2, places=5)
+    self.assertAlmostEqual(flags[2].item(), 0.3, places=5)
+
 
 if __name__ == '__main__':
   if xr.device_type() != 'CPU':
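
The last test exercises the group-local to global root-rank conversion inside
ProcessGroupXla.broadcast (the "variable name errors" its comment refers to):
BroadcastOptions.rootRank is local to the group, so the backend must translate
it to a global rank before issuing the collective. A minimal sketch of that
mapping, assuming the group records its members' global ranks in order (for
real groups, torch.distributed.get_global_rank performs this lookup); the
local_to_global helper is illustrative, not part of the patched code:

    def local_to_global(group_ranks, local_root):
        # For dist.new_group(ranks=[1]), group_ranks == [1], so the
        # group-local root 0 maps to global rank 1; in a single-member
        # group the source is the only member, so the broadcast is a no-op.
        return group_ranks[local_root]

    assert local_to_global([0], 0) == 0  # group [0]: local 0 -> global 0
    assert local_to_global([1], 0) == 1  # group [1]: local 0 -> global 1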