Skip to content

Commit 01407c7

Browse files
Merge pull request #7 from nicolas-chaulet/ball_query_2
Ball query 2
2 parents 7138bb8 + fdf87ec commit 01407c7

File tree

11 files changed

+295
-56
lines changed

11 files changed

+295
-56
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ import torch
1717
import torch_points.points_cuda
1818
```
1919

20+
## Build and test
21+
```
22+
python setup.py build_ext --inplace
23+
python -m unittest
24+
```
25+
2026
## Projects using those kernels.
2127

2228
[```Pytorch Point Cloud Benchmark```](https://github.com/nicolas-chaulet/deeppointcloud-benchmarks) by

cuda/include/ball_query.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
#pragma once
22
#include <torch/extension.h>
33

4-
at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
5-
const int nsample);
4+
at::Tensor ball_query_dense(at::Tensor new_xyz, at::Tensor xyz, const float radius,
5+
const int nsample);
6+
7+
std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x,
8+
at::Tensor y,
9+
at::Tensor batch_x,
10+
at::Tensor batch_y,
11+
const float radius,
12+
const int nsample);
13+
14+
at::Tensor degree(at::Tensor row, int64_t num_nodes);

cuda/include/cuda_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
#include <vector>
1212

13-
#define TOTAL_THREADS 512
13+
#define TOTAL_THREADS 1024
1414

1515
inline int opt_n_threads(int work_size) {
1616
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);

cuda/src/ball_query.cpp

Lines changed: 79 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
#include "ball_query.h"
22
#include "utils.h"
33

4-
void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
5-
int nsample, const float *new_xyz,
6-
const float *xyz, int *idx);
4+
void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius,
5+
int nsample, const float *new_xyz,
6+
const float *xyz, int *idx);
77

8-
at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
9-
const int nsample) {
8+
void query_ball_point_kernel_partial_wrapper(long batch_size,
9+
int size_x,
10+
int size_y,
11+
float radius,
12+
int nsample,
13+
const float *x,
14+
const float *y,
15+
const long *batch_x,
16+
const long *batch_y,
17+
long *idx_out,
18+
float *dist_out);
19+
20+
at::Tensor ball_query_dense(at::Tensor new_xyz, at::Tensor xyz, const float radius,
21+
const int nsample) {
1022
CHECK_CONTIGUOUS(new_xyz);
1123
CHECK_CONTIGUOUS(xyz);
1224
CHECK_IS_FLOAT(new_xyz);
@@ -21,12 +33,71 @@ at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
2133
at::device(new_xyz.device()).dtype(at::ScalarType::Int));
2234

2335
if (new_xyz.type().is_cuda()) {
24-
query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
25-
radius, nsample, new_xyz.data<float>(),
26-
xyz.data<float>(), idx.data<int>());
36+
query_ball_point_kernel_dense_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
37+
radius, nsample, new_xyz.data<float>(),
38+
xyz.data<float>(), idx.data<int>());
2739
} else {
2840
AT_CHECK(false, "CPU not supported");
2941
}
3042

3143
return idx;
3244
}
45+
46+
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
47+
auto zero = at::zeros(num_nodes, row.options());
48+
auto one = at::ones(row.size(0), row.options());
49+
return zero.scatter_add_(0, row, one);
50+
}
51+
52+
std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x,
53+
at::Tensor y,
54+
at::Tensor batch_x,
55+
at::Tensor batch_y,
56+
const float radius,
57+
const int nsample) {
58+
CHECK_CONTIGUOUS(x);
59+
CHECK_CONTIGUOUS(y);
60+
CHECK_IS_FLOAT(x);
61+
CHECK_IS_FLOAT(y);
62+
63+
if (x.type().is_cuda()) {
64+
CHECK_CUDA(x);
65+
CHECK_CUDA(y);
66+
CHECK_CUDA(batch_x);
67+
CHECK_CUDA(batch_y);
68+
}
69+
70+
at::Tensor idx = torch::full({x.size(0), nsample}, y.size(0),
71+
at::device(x.device()).dtype(at::ScalarType::Long));
72+
73+
at::Tensor dist = torch::full({x.size(0), nsample}, -1,
74+
at::device(x.device()).dtype(at::ScalarType::Float));
75+
76+
cudaSetDevice(x.get_device());
77+
auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
78+
cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
79+
cudaMemcpyDeviceToHost);
80+
auto batch_size = batch_sizes[0] + 1;
81+
82+
batch_x = degree(batch_x, batch_size);
83+
batch_x = at::cat({at::zeros(1, batch_x.options()), batch_x.cumsum(0)}, 0);
84+
batch_y = degree(batch_y, batch_size);
85+
batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);
86+
87+
if (x.type().is_cuda()) {
88+
query_ball_point_kernel_partial_wrapper(batch_size,
89+
x.size(0),
90+
y.size(0),
91+
radius, nsample,
92+
x.data<float>(),
93+
y.data<float>(),
94+
batch_x.data<long>(),
95+
batch_y.data<long>(),
96+
idx.data<long>(),
97+
dist.data<float>());
98+
} else {
99+
AT_CHECK(false, "CPU not supported");
100+
}
101+
102+
return std::make_pair(idx, dist);
103+
}

cuda/src/ball_query_gpu.cu

Lines changed: 74 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66

77
// input: new_xyz(b, m, 3) xyz(b, n, 3)
88
// output: idx(b, m, nsample)
9-
__global__ void query_ball_point_kernel(int b, int n, int m, float radius,
10-
int nsample,
11-
const float *__restrict__ new_xyz,
12-
const float *__restrict__ xyz,
13-
int *__restrict__ idx) {
9+
__global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius,
10+
int nsample,
11+
const float *__restrict__ new_xyz,
12+
const float *__restrict__ xyz,
13+
int *__restrict__ idx_out) {
14+
1415
int batch_index = blockIdx.x;
1516
xyz += batch_index * n * 3;
1617
new_xyz += batch_index * m * 3;
17-
idx += m * nsample * batch_index;
18+
idx_out += m * nsample * batch_index;
1819

1920
int index = threadIdx.x;
2021
int stride = blockDim.x;
@@ -33,22 +34,83 @@ __global__ void query_ball_point_kernel(int b, int n, int m, float radius,
3334
if (d2 < radius2) {
3435
if (cnt == 0) {
3536
for (int l = 0; l < nsample; ++l) {
36-
idx[j * nsample + l] = k;
37+
idx_out[j * nsample + l] = k;
3738
}
3839
}
39-
idx[j * nsample + cnt] = k;
40+
idx_out[j * nsample + cnt] = k;
4041
++cnt;
4142
}
4243
}
4344
}
4445
}
4546

46-
void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
47-
int nsample, const float *new_xyz,
48-
const float *xyz, int *idx) {
47+
__global__ void query_ball_point_kernel_partial_dense(int size_x,
48+
int size_y,
49+
float radius,
50+
int nsample,
51+
const float *__restrict__ x,
52+
const float *__restrict__ y,
53+
const long *__restrict__ batch_x,
54+
const long *__restrict__ batch_y,
55+
int64_t *__restrict__ idx_out,
56+
float * __restrict__ dist_out) {
57+
58+
// taken from https://github.com/rusty1s/pytorch_cluster/blob/master/cuda/radius_kernel.cu
59+
const ptrdiff_t batch_idx = blockIdx.x;
60+
const ptrdiff_t idx = threadIdx.x;
61+
62+
const ptrdiff_t start_idx_x = batch_x[batch_idx];
63+
const ptrdiff_t end_idx_x = batch_x[batch_idx + 1];
64+
65+
const ptrdiff_t start_idx_y = batch_y[batch_idx];
66+
const ptrdiff_t end_idx_y = batch_y[batch_idx + 1];
67+
float radius2 = radius * radius;
68+
69+
for (ptrdiff_t n_x = start_idx_x + idx; n_x < end_idx_x; n_x += TOTAL_THREADS) {
70+
int64_t count = 0;
71+
for (ptrdiff_t n_y = start_idx_y; n_y < end_idx_y; n_y++) {
72+
float dist = 0;
73+
for (ptrdiff_t d = 0; d < 3; d++) {
74+
dist += (x[n_x * 3 + d] - y[n_y * 3 + d]) *
75+
(x[n_x * 3 + d] - y[n_y * 3 + d]);
76+
}
77+
if(dist <= radius2){
78+
idx_out[n_x * nsample + count] = n_y;
79+
dist_out[n_x * nsample + count] = dist;
80+
count++;
81+
}
82+
if(count >= nsample){
83+
break;
84+
}
85+
}
86+
}
87+
}
88+
89+
void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius,
90+
int nsample, const float *new_xyz,
91+
const float *xyz, int *idx) {
4992
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
50-
query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
93+
query_ball_point_kernel_dense<<<b, opt_n_threads(m), 0, stream>>>(
5194
b, n, m, radius, nsample, new_xyz, xyz, idx);
5295

5396
CUDA_CHECK_ERRORS();
5497
}
98+
99+
void query_ball_point_kernel_partial_wrapper(long batch_size,
100+
int size_x,
101+
int size_y,
102+
float radius,
103+
int nsample,
104+
const float *x,
105+
const float *y,
106+
const long *batch_x,
107+
const long *batch_y,
108+
int64_t *idx_out,
109+
float *dist_out) {
110+
111+
query_ball_point_kernel_partial_dense<<<batch_size, TOTAL_THREADS>>>(
112+
size_x, size_y, radius, nsample, x, y,
113+
batch_x, batch_y, idx_out, dist_out);
114+
115+
CUDA_CHECK_ERRORS();
116+
}

cuda/src/bindings.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
1212
m.def("three_interpolate", &three_interpolate);
1313
m.def("three_interpolate_grad", &three_interpolate_grad);
1414

15-
m.def("ball_query", &ball_query);
15+
m.def("ball_query_dense", &ball_query_dense);
16+
m.def("ball_query_partial_dense", &ball_query_partial_dense);
1617

1718
m.def("group_points", &group_points);
1819
m.def("group_points_grad", &group_points_grad);

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
setup(
3737
name="torch_points",
38-
version="0.1.2",
38+
version="0.1.3",
3939
author="Nicolas Chaulet",
4040
packages=find_packages(),
4141
install_requires=[],

test/test_ballquerry.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
import torch
3-
from torch_points import ball_query
3+
from torch_points import ball_query_dense
44
import numpy.testing as npt
55
import numpy as np
66

@@ -10,19 +10,17 @@ def test_simple_gpu(self):
1010
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float).cuda()
1111
b = torch.tensor([[[0, 0, 0]]]).to(torch.float).cuda()
1212

13-
npt.assert_array_equal(ball_query(1, 2, a, b).detach().cpu().numpy(), np.array([[[0, 0]]]))
13+
npt.assert_array_equal(ball_query_dense(1, 2, a, b).detach().cpu().numpy(), np.array([[[0, 0]]]))
1414

1515
def test_simple_cpu(self):
1616
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float)
1717
b = torch.tensor([[[0, 0, 0]]]).to(torch.float)
18-
npt.assert_array_equal(ball_query(1, 2, a, b).detach().numpy(), np.array([[[0, 0]]]))
18+
npt.assert_array_equal(ball_query_dense(1, 2, a, b).detach().numpy(), np.array([[[0, 0]]]))
1919

2020
def test_cpu_gpu_equality(self):
2121
a = torch.randn(5, 1000, 3)
22-
npt.assert_array_equal(ball_query(0.1, 17, a, a).detach().numpy(),
23-
ball_query(0.1, 17, a.cuda(), a.cuda()).detach().numpy())
24-
25-
22+
npt.assert_array_equal(ball_query_dense(0.1, 17, a, a).detach().numpy(),
23+
ball_query_dense(0.1, 17, a.cuda(), a.cuda()).cpu().detach().numpy())
2624

2725

2826
if __name__ == "__main__":

test/test_ballquerry_partial.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import unittest
2+
import torch
3+
from torch_points import ball_query
4+
from torch_cluster import radius_cuda
5+
import numpy.testing as npt
6+
import numpy as np
7+
8+
class TestBallPartial(unittest.TestCase):
9+
def test_simple_gpu(self):
10+
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]]).to(torch.float).cuda()
11+
y = torch.tensor([[0, 0, 0]]).to(torch.float).cuda()
12+
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
13+
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()
14+
15+
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
16+
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()
17+
18+
idx, dist2 = ball_query(1., 2, x, y, batch_x, batch_y, mode="PARTIAL_DENSE")
19+
20+
idx = idx.detach().cpu().numpy()
21+
dist2 = dist2.detach().cpu().numpy()
22+
23+
idx_answer = np.asarray([[1, 1], [0, 1], [1, 1], [1, 1]])
24+
dist2_answer = np.asarray([[-1, -1], [0.01, -1], [-1, -1], [-1, -1]]).astype(np.float32)
25+
26+
npt.assert_array_almost_equal(idx, idx_answer)
27+
npt.assert_array_almost_equal(dist2, dist2_answer)
28+
29+
if __name__ == "__main__":
30+
unittest.main()

0 commit comments

Comments
 (0)