Skip to content

Commit 92082f9

Browse files
committed
Faster coalesce when no value is provided
1 parent 3c7253a commit 92082f9

File tree

6 files changed

+79
-8
lines changed

6 files changed

+79
-8
lines changed

cuda/unique.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// Python-facing entry point of the CUDA "unique" extension: validates the
// input device, then forwards to the implementation in unique_kernel.cu.
#include <torch/torch.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

// Defined in cuda/unique_kernel.cu.
std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src);

// Asserts that `src` is a CUDA tensor and delegates to the kernel wrapper.
// Returns the (values, permutation) pair produced by unique_cuda.
std::tuple<at::Tensor, at::Tensor> unique(at::Tensor src) {
  CHECK_CUDA(src);
  return unique_cuda(src);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  mod.def("unique", &unique, "Unique (CUDA)");
}

cuda/unique_kernel.cu

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#include <ATen/ATen.h>

// Launch configuration: fixed thread count per block, ceiling-divide to get
// enough blocks to cover N elements. NOTE: the macro argument is fully
// parenthesized — the original `(N + THREADS - 1) / THREADS` would expand
// incorrectly for an expression argument such as BLOCKS(a + b).
#define THREADS 1024
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Marks the first element of each run of equal values in a *sorted* array:
// mask[i] is set to 1 iff i == 0 or src[i] differs from src[i - 1].
// Duplicates leave mask[i] untouched, so the caller must zero-fill `mask`.
// Uses a grid-stride loop (index advances by blockDim.x * gridDim.x) so any
// grid size covers all `numel` elements.
template <typename scalar_t>
__global__ void unique_cuda_kernel(scalar_t *__restrict__ src, uint8_t *mask,
                                   size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = index; i < numel; i += stride) {
    if (i == 0 || src[i] != src[i - 1]) {
      mask[i] = 1;
    }
  }
}
17+
18+
// Computes the sorted unique values of a 1-D CUDA tensor.
// Sorts `src`, flags the first occurrence of each run of equal values with a
// byte mask, then compacts both the sorted values and the sort permutation.
// Returns (values, perm), where perm[i] is the index in the *original*
// tensor of the i-th unique value.
std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
  at::Tensor perm;
  std::tie(src, perm) = src.sort();

  auto mask = at::zeros(src.numel(), src.type().toScalarType(at::kByte));

  // Guard against empty input: <<<BLOCKS(0), THREADS>>> would launch zero
  // blocks, which is an invalid CUDA launch configuration. With no elements
  // there is nothing to mark, and masked_select on the empty mask is a no-op.
  if (src.numel() > 0) {
    AT_DISPATCH_ALL_TYPES(src.type(), "unique_cuda_kernel", [&] {
      unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>(
          src.data<scalar_t>(), mask.data<uint8_t>(), src.numel());
    });
  }

  src = src.masked_select(mask);
  perm = perm.masked_select(mask);

  return std::make_tuple(src, perm);
}

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
'spspmm_cuda',
1818
['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'],
1919
extra_link_args=['-lcusparse'],
20-
)
20+
),
21+
CUDAExtension('unique_cuda',
22+
['cuda/unique.cpp', 'cuda/unique_kernel.cu'])
2123
]
2224
cmdclass['build_ext'] = BuildExtension
2325

torch_sparse/coalesce.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import torch
22
import torch_scatter
33

4+
from .utils.unique import unique
5+
46

57
def coalesce(index, value, m, n, op='add', fill_value=0):
68
"""Row-wise sorts :obj:`value` and removes duplicate entries. Duplicate
@@ -23,16 +25,20 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
2325

2426
row, col = index
2527

26-
unique, inv = torch.unique(row * n + col, sorted=True, return_inverse=True)
28+
if value is None:
29+
_, perm = unique(row * n + col)
30+
index = torch.stack([row[perm], col[perm]], dim=0)
31+
return index, value
32+
33+
uniq, inv = torch.unique(row * n + col, sorted=True, return_inverse=True)
2734

2835
perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
29-
perm = inv.new_empty(unique.size(0)).scatter_(0, inv, perm)
36+
perm = inv.new_empty(uniq.size(0)).scatter_(0, inv, perm)
3037
index = torch.stack([row[perm], col[perm]], dim=0)
3138

32-
if value is not None:
33-
op = getattr(torch_scatter, 'scatter_{}'.format(op))
34-
value = op(value, inv, 0, None, perm.size(0), fill_value)
35-
if isinstance(value, tuple):
36-
value = value[0]
39+
op = getattr(torch_scatter, 'scatter_{}'.format(op))
40+
value = op(value, inv, 0, None, perm.size(0), fill_value)
41+
if isinstance(value, tuple):
42+
value = value[0]
3743

3844
return index, value

torch_sparse/utils/__init__.py

Whitespace-only changes.

torch_sparse/utils/unique.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import torch
2+
import numpy as np
3+
4+
if torch.cuda.is_available():
5+
import unique_cuda
6+
7+
8+
def unique(src):
    """Return the sorted unique values of ``src`` and, for each value, the
    index of its first occurrence in the flattened input.

    GPU tensors are handled by the compiled ``unique_cuda`` extension; CPU
    tensors go through :func:`numpy.unique` with ``return_index=True``.
    """
    flat = src.contiguous().view(-1)

    if flat.is_cuda:
        return unique_cuda.unique(flat)

    values, first_idx = np.unique(flat.numpy(), return_index=True)
    return torch.from_numpy(values), torch.from_numpy(first_idx)

0 commit comments

Comments
 (0)