Skip to content

Commit 17b9c61

Browse files
kwen2501 authored and pytorchmergebot committed
[a2av] not returning out tensor from ops (pytorch#159435)
torch.compile of `all_to_all_vdev_2d` hits the following error: ``` torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised: RuntimeError: Found a custom (non-ATen) operator whose output has alias annotations: symm_mem::all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> Tensor(a!). We only support functionalizing operators whose outputs do not have alias annotations (e.g. 'Tensor(a)' is a Tensor with an alias annotation whereas 'Tensor' is a Tensor without. The '(a)' is the alias annotation). The alias annotation specifies that the output Tensor shares storage with an input that has the same annotation. Please check if (1) the output needs to be an output (if not, don't return it), (2) if the output doesn't share storage with any inputs, then delete the alias annotation. (3) if the output indeed shares storage with an input, then add a .clone() before returning it to prevent storage sharing and then delete the alias annotation. Otherwise, please file an issue on GitHub. ``` This PR selects option (1). Pull Request resolved: pytorch#159435 Approved by: https://github.com/ngimel, https://github.com/xmfan
1 parent d3ce450 commit 17b9c61

File tree

4 files changed

+29
-8
lines changed

4 files changed

+29
-8
lines changed

torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,9 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
342342
m.def(
343343
"all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
344344
m.def(
345-
"all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> Tensor(a!)");
345+
"all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> ()");
346346
m.def(
347-
"all_to_all_vdev_2d_offset(Tensor input, Tensor(a!) out, Tensor in_splits_offsets, Tensor(a!) out_splits_offsets, str group_name) -> Tensor(a!)");
347+
"all_to_all_vdev_2d_offset(Tensor input, Tensor(a!) out, Tensor in_splits_offsets, Tensor(a!) out_splits_offsets, str group_name) -> ()");
348348
}
349349

350350
TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {

torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_split
539539
#endif
540540
}
541541

542-
at::Tensor all_to_all_vdev_2d(
542+
void all_to_all_vdev_2d(
543543
at::Tensor& input,
544544
at::Tensor& out,
545545
at::Tensor& in_splits,
@@ -685,10 +685,9 @@ at::Tensor all_to_all_vdev_2d(
685685
args1,
686686
0,
687687
stream);
688-
return out;
689688
}
690689

691-
at::Tensor all_to_all_vdev_2d_offset(
690+
void all_to_all_vdev_2d_offset(
692691
at::Tensor& input,
693692
at::Tensor& out,
694693
at::Tensor& in_splits_offsets,
@@ -819,7 +818,6 @@ at::Tensor all_to_all_vdev_2d_offset(
819818
args1,
820819
0,
821820
stream);
822-
return out;
823821
}
824822
} // namespace c10d::nvshmem_extension
825823

torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@ at::Tensor all_to_all_vdev(
3838
at::Tensor& in_out_splits,
3939
std::string group_name);
4040

41-
at::Tensor all_to_all_vdev_2d(
41+
void all_to_all_vdev_2d(
4242
at::Tensor& input,
4343
at::Tensor& out,
4444
at::Tensor& in_splits,
4545
at::Tensor& out_splits_offsets,
4646
std::string group_name,
4747
std::optional<int64_t> major_align = std::nullopt);
4848

49-
at::Tensor all_to_all_vdev_2d_offset(
49+
void all_to_all_vdev_2d_offset(
5050
at::Tensor& input,
5151
at::Tensor& out,
5252
at::Tensor& in_splits_offsets,

torch/distributed/_symmetric_memory/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1609,6 +1609,29 @@ def _low_contention_reduce_scatter(
16091609
)
16101610

16111611

1612+
@torch.library.impl(lib, "all_to_all_vdev_2d", "Meta")
1613+
def _all_to_all_vdev_2d_meta(
1614+
input: torch.Tensor,
1615+
out: torch.Tensor,
1616+
in_splits: torch.Tensor,
1617+
out_splits_offsets: torch.Tensor,
1618+
group_name: str,
1619+
major_align: Optional[int] = None,
1620+
) -> None:
1621+
return None
1622+
1623+
1624+
@torch.library.impl(lib, "all_to_all_vdev_2d_offset", "Meta")
1625+
def _all_to_all_vdev_2d_offset_meta(
1626+
input: torch.Tensor,
1627+
out: torch.Tensor,
1628+
in_splits_offsets: torch.Tensor,
1629+
out_splits_offsets: torch.Tensor,
1630+
group_name: str,
1631+
) -> None:
1632+
return None
1633+
1634+
16121635
# =============================================================================
16131636
# User-facing APIs
16141637
# =============================================================================

0 commit comments

Comments (0)