Skip to content

Commit fa3c38c

Browse files
zeshengzongpytorchmergebot
authored andcommitted
Add tensor overlap check for cross (pytorch#154999)
Fixes pytorch#132031 ## Test Result ```python In [1]: import torch ...: torch.manual_seed(0) ...: torch.cuda.manual_seed(0) ...: a = torch.randn(3, 4) ...: b = torch.randn(3, 4) ...: torch.cross(a, b, out=a) --------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) Cell In[1], line 6 4 a = torch.randn(3, 4) 5 b = torch.randn(3, 4) ----> 6 torch.cross(a, b, out=a) RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation. ``` Pull Request resolved: pytorch#154999 Approved by: https://github.com/lezcano
1 parent 5b65628 commit fa3c38c

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

aten/src/ATen/native/Cross.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <ATen/WrapDimUtils.h>
77
#include <ATen/ExpandUtils.h>
88
#include <ATen/native/Resize.h>
9+
#include <ATen/MemoryOverlap.h>
910

1011

1112
#ifndef AT_PER_OPERATOR_HEADERS
@@ -77,6 +78,9 @@ Tensor & cross_out(const Tensor & input, const Tensor & other, const std::option
7778

7879
TORCH_IMPL_FUNC(linalg_cross_out)
7980
(const Tensor & input, const Tensor & other, int64_t dim, const Tensor & out) {
81+
at::assert_no_internal_overlap(out);
82+
at::assert_no_overlap(out, input);
83+
at::assert_no_overlap(out, other);
8084
dim = maybe_wrap_dim(dim, input.dim());
8185
auto out_size = out.sizes();
8286
Tensor input_broadcasted = input.expand(out_size);

test/test_linalg.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6022,6 +6022,18 @@ def test_linalg_cross_with_and_without_dim(self, device, dtype):
60226022
self.assertEqual(res1, res2)
60236023
self.assertEqual(res1, res3)
60246024

6025+
def test_cross_error(self, device):
6026+
x = torch.randn(4, 3, device=device)
6027+
y = torch.randn(4, 3, device=device)
6028+
with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"):
6029+
torch.cross(x, y, out=x)
6030+
with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"):
6031+
torch.cross(y, x, out=x)
6032+
with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"):
6033+
torch.linalg.cross(x, y, out=x)
6034+
with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"):
6035+
torch.linalg.cross(y, x, out=x)
6036+
60256037
def test_renorm(self, device):
60266038
m1 = torch.randn(20, 20, device=device) # big enough to exercise vectorized path
60276039
res1 = torch.tensor((), device=device)

0 commit comments

Comments
 (0)