
Commit 173a770

implement collective reduce op (#9437)
1 parent 51518f9 commit 173a770

3 files changed: +63 -30 lines changed

3 files changed

+63
-30
lines changed
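In user code, the effect of this commit is that torch.distributed.reduce no longer raises NotImplementedError on the xla backend. Below is a minimal per-process sketch of the intended usage, mirroring the new test added in this commit; the launch mechanism (e.g. torchrun or pjrt.run_multiprocess) is assumed and not part of this diff.

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.runtime as xr


def reduce_example():
  # Each rank contributes its own ordinal; after the reduce, only dst=0
  # holds the sum, all other ranks keep their original tensor.
  dist.init_process_group("xla", init_method='xla://')
  device = torch_xla.device()
  t = torch.tensor([xr.global_ordinal()], dtype=torch.float, device=device)
  dist.reduce(t, dst=0, op=dist.ReduceOp.SUM)
  return t.cpu()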

test/pjrt/test_collective_ops_tpu.py

Lines changed: 43 additions & 23 deletions
@@ -103,29 +103,6 @@ def test_reduce_scatter(self, pin_layout):
     for ordinal, value in results.items():
       np.testing.assert_array_equal(value, [-ordinal])
 
-  @staticmethod
-  def _scatter():
-    dist.init_process_group("xla", init_method='xla://')
-    device = torch_xla.device()
-    world_size = xr.world_size()
-    tensors = None
-    if xr.global_ordinal() == 0:
-      tensors = [
-          torch.tensor([i], device=device, dtype=torch.float)
-          for i in range(world_size)
-      ]
-
-    output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
-    dist.scatter(output_tensor, tensors, src=0)
-    return output_tensor.cpu()
-
-  def test_scatter(self):
-    """self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
-    on device 0, then scatters it. Device i should therefore receive [i]."""
-    results = pjrt.run_multiprocess(self._scatter)
-    for ordinal, value in results.items():
-      np.testing.assert_array_equal(value, [ordinal])
-
   @staticmethod
   def _all_to_all(pin_layout):
     device = torch_xla.device()
@@ -359,6 +336,49 @@ def test_all_to_all_single(self, use_dynamo):
             expected.sort().values),
         f"Got {val}, expected {expected}")
 
+  @staticmethod
+  def _scatter():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    world_size = xr.world_size()
+    tensors = None
+    if xr.global_ordinal() == 0:
+      tensors = [
+          torch.tensor([i], device=device, dtype=torch.float)
+          for i in range(world_size)
+      ]
+
+    output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
+    dist.scatter(output_tensor, tensors, src=0)
+    return output_tensor.cpu()
+
+  def test_scatter(self):
+    """self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
+    on device 0, then scatters it. Device i should therefore receive [i]."""
+    results = pjrt.run_multiprocess(self._scatter)
+    for ordinal, value in results.items():
+      np.testing.assert_array_equal(value, [ordinal])
+
+  @staticmethod
+  def _reduce():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    input = torch.tensor([xr.global_ordinal()],
+                         dtype=torch.float,
+                         device=device)
+    dist.reduce(input, dst=0, op=dist.ReduceOp.SUM)
+
+    return input.cpu()
+
+  def test_reduce(self):
+    results = pjrt.run_multiprocess(self._reduce)
+    for ordinal, value in results.items():
+      if ordinal == 0:
+        expected = sum(range(tpu.num_expected_global_devices()))
+      else:
+        expected = ordinal
+      np.testing.assert_array_equal(value, [expected])
+
 
 if __name__ == '__main__':
   absltest.main()
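To make the expected values in test_reduce concrete, here is the arithmetic on a hypothetical 4-device host; the device count is an assumption for illustration only, the test itself derives it from tpu.num_expected_global_devices().

# Hypothetical 4-device run: ranks start with [0.], [1.], [2.], [3.].
# After dist.reduce(input, dst=0, op=SUM):
#   rank 0:    [0 + 1 + 2 + 3] == [6.]   (i.e. sum(range(4)))
#   ranks 1-3: unchanged, i.e. [1.], [2.], [3.]
world_size = 4                               # assumed for illustration
expected_on_rank_0 = sum(range(world_size))  # -> 6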

test/test_torch_distributed_xla_backend.py

Lines changed: 0 additions & 1 deletion
@@ -356,7 +356,6 @@ def test_barrier(self):
     dist.barrier()
 
   @parameterized.parameters(
-      'reduce',
       'allreduce_coalesced',
       'alltoall',
       'gather',
torch_xla/distributed/xla_backend.py

Lines changed: 20 additions & 6 deletions
@@ -1,11 +1,12 @@
 import torch
 import torch.distributed as dist
+import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr
 from torch_xla._internal import rendezvous
 import logging
 import os
-from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions
+from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, ReduceOptions
 
 
 def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -224,11 +225,24 @@ def _reduce_scatter_base(self, output_tensor, input_tensor, opts):
   def barrier(self, opts):
     return _ret_work([])
 
-  # Call site:
-  # https://github.com/pytorch/pytorch/blob/70f57bcb1e45d21532bdb1c44d3aab018d1cbe88/torch/distributed/distributed_c10d.py#L1417
-  # `reduce` is not needed by DeepSpeed for now.
-  def reduce(self, *args):
-    raise NotImplementedError
+  # Called by torch.distributed.reduce. Call site example:
+  # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L2925
+  # Tensors are reduced but result is only saved on dst device.
+  # Input tensor is unchanged on all other devices.
+  # This is an inefficient operation. In order to avoid XLA deadlocks it
+  # performs redundant reductions on all devices and materializes the result.
+  def reduce(self, tensors: list[torch.Tensor], opts: ReduceOptions):
+    rank = xr.global_ordinal()
+    dst = opts.rootRank
+    reduce_type = self._get_reduce_type(opts.reduceOp)
+    for tensor in tensors:
+      result = xm.all_reduce(reduce_type, inputs=tensor)
+      torch_xla.sync()
+
+      if rank == dst:
+        tensor.copy_(result)
+
+    return _ret_work(tensors)
 
   def allreduce_coalesced(self, *args):
     raise NotImplementedError
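The comments in the new reduce method describe the strategy: every replica issues the same all_reduce so the per-replica programs stay identical (which is how the deadlock concern mentioned above is avoided), and only the destination rank copies the result back into its tensor. A toy, single-process sketch of that idea using plain tensors follows; the four simulated ranks are an assumption for illustration and not part of the commit.

import torch

dst = 0
ranks = range(4)                                       # simulated ranks
tensors = [torch.tensor([float(r)]) for r in ranks]    # per-rank inputs
all_reduced = torch.stack(tensors).sum(dim=0)          # what every rank would compute
for rank, t in zip(ranks, tensors):
  if rank == dst:
    t.copy_(all_reduced)  # dst ends up with [6.]
  # non-dst ranks keep their original value, matching torch.distributed.reduce semantics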
