Skip to content

Commit aaf5615

Browse files
dccipytorchmergebot
authored andcommitted
[cpu/sorting] Throw an error when trying to sort complex numbers. (pytorch#144113)
It doesn't really make sense to sort complex numbers as they are not comparable. Fixes pytorch#129296 Pull Request resolved: pytorch#144113 Approved by: https://github.com/malfet
1 parent 78eded8 commit aaf5615

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

aten/src/ATen/native/Sorting.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ TORCH_META_FUNC2(sort, stable)
7474
(const Tensor& self, std::optional<bool> stable, int64_t dim, bool descending) {
7575
maybe_wrap_dim(dim, self.dim());
7676

77+
const auto self_dtype = self.dtype();
78+
TORCH_CHECK_VALUE(
79+
self_dtype != ScalarType::ComplexFloat &&
80+
self_dtype != ScalarType::ComplexDouble,
81+
"Sort currently does not support complex dtypes on CPU.");
82+
7783
// See issue: https://github.com/pytorch/pytorch/issues/65863
7884
// Strides should be dense, so as not to allocate too much memory.
7985
// We either use 'self' strides, or infer dense strides from them.

test/test_sort_and_select.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,14 @@ def test_sort_stable_none(self):
175175
y = x.sort(stable=None).values
176176
self.assertTrue(torch.all(y == torch.ones(10)).item())
177177

178+
@onlyCPU
179+
def test_complex_unsupported_cpu(self):
180+
x = torch.tensor([3.0 + 2j, 4.0 + 3j])
181+
with self.assertRaisesRegex(
182+
ValueError, "Sort currently does not support complex dtypes on CPU."
183+
):
184+
torch.sort(input=x)
185+
178186
@onlyCUDA
179187
def test_sort_large_slice(self, device):
180188
# tests direct cub path

0 commit comments

Comments
 (0)