@@ -103,29 +103,6 @@ def test_reduce_scatter(self, pin_layout):
     for ordinal, value in results.items():
       np.testing.assert_array_equal(value, [-ordinal])
 
-  @staticmethod
-  def _scatter():
-    dist.init_process_group("xla", init_method='xla://')
-    device = torch_xla.device()
-    world_size = xr.world_size()
-    tensors = None
-    if xr.global_ordinal() == 0:
-      tensors = [
-          torch.tensor([i], device=device, dtype=torch.float)
-          for i in range(world_size)
-      ]
-
-    output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
-    dist.scatter(output_tensor, tensors, src=0)
-    return output_tensor.cpu()
-
-  def test_scatter(self):
-    """self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
-    on device 0, then scatters it. Device i should therefore receive [i]."""
-    results = pjrt.run_multiprocess(self._scatter)
-    for ordinal, value in results.items():
-      np.testing.assert_array_equal(value, [ordinal])
-
   @staticmethod
   def _all_to_all(pin_layout):
     device = torch_xla.device()
@@ -359,6 +336,49 @@ def test_all_to_all_single(self, use_dynamo):
                          expected.sort().values),
           f"Got {val}, expected {expected}")
 
+  @staticmethod
+  def _scatter():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    world_size = xr.world_size()
+    tensors = None
+    if xr.global_ordinal() == 0:
+      tensors = [
+          torch.tensor([i], device=device, dtype=torch.float)
+          for i in range(world_size)
+      ]
+
+    output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
+    dist.scatter(output_tensor, tensors, src=0)
+    return output_tensor.cpu()
+
+  def test_scatter(self):
+    """self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
+    on device 0, then scatters it. Device i should therefore receive [i]."""
+    results = pjrt.run_multiprocess(self._scatter)
+    for ordinal, value in results.items():
+      np.testing.assert_array_equal(value, [ordinal])
+
+  @staticmethod
+  def _reduce():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    input = torch.tensor([xr.global_ordinal()],
+                         dtype=torch.float,
+                         device=device)
+    dist.reduce(input, dst=0, op=dist.ReduceOp.SUM)
+
+    return input.cpu()
+
+  def test_reduce(self):
+    results = pjrt.run_multiprocess(self._reduce)
+    for ordinal, value in results.items():
+      if ordinal == 0:
+        expected = sum(range(tpu.num_expected_global_devices()))
+      else:
+        expected = ordinal
+      np.testing.assert_array_equal(value, [expected])
+
 
 if __name__ == '__main__':
   absltest.main()
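
For reference, a minimal sketch (not part of the commit) of the values the new test_reduce asserts: dist.reduce with op=SUM and dst=0 accumulates every rank's tensor into rank 0 and leaves the other ranks' buffers unchanged. The world size of 4 below is an assumption for illustration; the test itself derives it from tpu.num_expected_global_devices().

# Hypothetical illustration of the expectations in test_reduce above;
# world_size = 4 is an assumed device count, not taken from the commit.
world_size = 4

# Each rank contributes a one-element tensor [ordinal].
inputs = {ordinal: [float(ordinal)] for ordinal in range(world_size)}

# After dist.reduce(..., dst=0, op=SUM): rank 0 holds the sum of all
# contributions; every other rank keeps its original value.
expected = {
    ordinal: [float(sum(range(world_size)))] if ordinal == 0 else [float(ordinal)]
    for ordinal in range(world_size)
}

print(expected)  # {0: [6.0], 1: [1.0], 2: [2.0], 3: [3.0]}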