6 | 6 | from torch_xla._internal import rendezvous
7 | 7 | import logging
8 | 8 | import os
9 |   | -from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, ReduceOptions
  | 9 | +from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, AllToAllOptions, ReduceOptions
10 | 10 |
11 | 11 |
12 | 12 | def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -247,8 +247,24 @@ def reduce(self, tensors: list[torch.Tensor], opts: ReduceOptions):
247 | 247 |   def allreduce_coalesced(self, *args):
248 | 248 |     raise NotImplementedError
249 | 249 |
250 |     | -  def alltoall(self, *args):
251 |     | -    raise NotImplementedError
    | 250 | +  # Called by torch.distributed.all_to_all. Call site example:
    | 251 | +  # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4577
    | 252 | +  # The difference between this and all_to_all_single is that this works
    | 253 | +  # on a list of tensors while all_to_all_single works on a single tensor
    | 254 | +  # and splits/concats along dimension 0.
    | 255 | +  def alltoall(self, output_tensor_list: list[torch.Tensor],
    | 256 | +               input_tensor_list: list[torch.Tensor], opts: AllToAllOptions):
    | 257 | +    stacked_inputs = torch.stack(input_tensor_list, dim=0)
    | 258 | +    split_count = len(input_tensor_list)
    | 259 | +    stacked_results = xm.all_to_all(
    | 260 | +        stacked_inputs,
    | 261 | +        split_dimension=0,
    | 262 | +        concat_dimension=0,
    | 263 | +        split_count=split_count)
    | 264 | +    results = torch.chunk(stacked_results, split_count, dim=0)
    | 265 | +    for result, output_tensor in zip(results, output_tensor_list):
    | 266 | +      output_tensor.copy_(result.squeeze(dim=0))
    | 267 | +    return _ret_work(output_tensor_list)
252 | 268 |
253 | 269 |   # handle the nondynamo path when call torch.distributed.all_to_all_single
254 | 270 |   # call from https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3996
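For context, here is a minimal usage sketch (not part of the diff) of how the new `alltoall` hook could be exercised end to end through `torch.distributed.all_to_all` with the `xla` backend. The `demo_all_to_all` helper, tensor shapes, fill values, `init_method='xla://'` choice, and the launcher (e.g. `torchrun`) are illustrative assumptions, not code from this PR. Because the hook stacks the inputs with `torch.stack`, the sketch uses identically shaped tensors on every rank.

```python
# Hypothetical end-to-end sketch; assumes one XLA device per process and a
# torchrun-style launcher. Names below (demo_all_to_all, shapes, fill values)
# are illustrative, not part of the PR.
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' c10d backend


def demo_all_to_all():
  dist.init_process_group('xla', init_method='xla://')
  device = xm.xla_device()
  world_size = dist.get_world_size()
  rank = dist.get_rank()

  # Rank r contributes input_tensors[j] to rank j; after the collective,
  # output_tensors[s] holds the tensor that rank s addressed to this rank.
  input_tensors = [
      torch.full((2, 2), float(rank * world_size + j), device=device)
      for j in range(world_size)
  ]
  output_tensors = [
      torch.empty((2, 2), device=device) for _ in range(world_size)
  ]

  # List form of all_to_all, which dispatches to the alltoall hook added
  # in this diff.
  dist.all_to_all(output_tensors, input_tensors)
  xm.mark_step()  # flush the lazily traced XLA graph

  for s, received in enumerate(output_tensors):
    expected = float(s * world_size + rank)
    assert torch.all(received.cpu() == expected), (s, received)


if __name__ == '__main__':
  demo_all_to_all()
```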