Skip to content

Commit c5d928e

Browse files
Merge pull request #16 from nicolas-chaulet/unifyballquery
Unifyballquery
2 parents 6e34d96 + 59e189b commit c5d928e

File tree

12 files changed

+153
-146
lines changed

12 files changed

+153
-146
lines changed

cpu/src/bindings.cpp

Lines changed: 22 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -9,41 +9,41 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
99
m.def("ball_query",
1010
&ball_query,
1111
"compute the radius search of a point cloud using nanoflann"
12-
"-query : a pytorch tensor of size N1 x 3,. used to query the nearest neighbors"
13-
"- support : a pytorch tensor of size N2 x 3. used to build the tree"
14-
"- radius : float number, size of the ball for the radius search."
12+
"- support : a pytorch tensor of size N1 x 3, points where the neighboors are accessed from"
13+
"- query : a pytorch tensor of size N2 x 3, centre of the balls"
14+
"- radius : float number, size of the ball for the radius search."
1515
"- max_num : int number, indicate the maximum of neaghbors allowed(if -1 then all the possible neighbors will be computed). "
16-
" - mode : int number that indicate which format for the neighborhood"
17-
" mode=0 mean a matrix of neighbors(-1 for shadow neighbors)"
16+
"- mode : int number that indicate which format for the neighborhood"
17+
"mode=0 mean a matrix of neighbors(-1 for shadow neighbors)"
1818
"mode=1 means a matrix of edges of size Num_edge x 2"
19-
"return a tensor of size N1 x M where M is either max_num or the maximum number of neighbors found if mode = 0, if mode=1 return a tensor of size Num_edge x 2 and return a tensor containing the squared distance of the neighbors",
20-
"query"_a, "support"_a, "radius"_a, "max_num"_a=-1, "mode"_a=0);
19+
"return a tensor of size N2 x M where M is either max_num or the maximum number of neighbors found if mode = 0, if mode=1 return a tensor of size Num_edge x 2 and return a tensor containing the squared distance of the neighbors",
20+
"support"_a, "query"_a, "radius"_a, "max_num"_a=-1, "mode"_a=0);
2121

2222
m.def("batch_ball_query",
2323
&batch_ball_query,
2424
"compute the radius search of a point cloud for each batch using nanoflann"
25-
"-query : a pytorch tensor (float) of size N1 x 3,. used to query the nearest neighbors"
26-
"- support : a pytorch tensor(float) of size N2 x 3. used to build the tree"
27-
"- query_batch : a pytorch tensor(long) contains indices of the batch of the query size N1"
28-
"NB : the batch must be sorted"
29-
"- support_batch: a pytorch tensor(long) contains indices of the batch of the support size N2"
25+
"- support : a pytorch tensor of size N1 x 3, points where the neighboors are accessed from"
26+
"- query : a pytorch tensor of size N2 x 3, centre of the balls"
27+
"- support_batch: a pytorch tensor(long) contains indices of the batch of the support size N1"
3028
"NB: the batch must be sorted"
29+
"- query_batch : a pytorch tensor(long) contains indices of the batch of the query size N2"
30+
"NB : the batch must be sorted"
3131
"-radius: float number, size of the ball for the radius search."
3232
"- max_num : int number, indicate the maximum of neaghbors allowed(if -1 then all the possible neighbors wrt the radius will be computed)."
3333
"- mode : int number that indicate which format for the neighborhood"
34-
"mode=0 mean a matrix of neighbors(N2 for shadow neighbors)"
34+
"mode=0 mean a matrix of neighbors(N1 for shadow neighbors)"
3535
"mode=1 means a matrix of edges of size Num_edge x 2"
36-
"return a tensor of size N1 x M where M is either max_num or the maximum number of neighbors found if mode = 0, if mode=1 return a tensor of size Num_edge x 2 and return a tensor containing the squared distance of the neighbors",
37-
"query"_a, "support"_a, "query_batch"_a, "support_batch"_a, "radius"_a, "max_num"_a=-1, "mode"_a=0);
36+
"return a tensor of size N2 x M where M is either max_num or the maximum number of neighbors found if mode = 0, if mode=1 return a tensor of size Num_edge x 2 and return a tensor containing the squared distance of the neighbors",
37+
"support"_a, "query"_a, "support_batch"_a, "query_batch"_a, "radius"_a, "max_num"_a=-1, "mode"_a=0);
3838
m.def("dense_ball_query", &dense_ball_query,
3939
"compute the radius search of a batch of point cloud using nanoflann"
40-
"-query : a pytorch tensor of size B x N1 x 3,. used to query the nearest neighbors"
41-
"- support : a pytorch tensor of size B x N2 x 3. used to build the tree"
42-
"- radius : float number, size of the ball for the radius search."
40+
"- support : a pytorch tensor of size B x N1 x 3, points where the neighboors are accessed from"
41+
"- query : a pytorch tensor of size B x N2 x 3, centre of the balls"
42+
"- radius : float number, size of the ball for the radius search."
4343
"- max_num : int number, indicate the maximum of neaghbors allowed(if -1 then all the possible neighbors will be computed). "
44-
" - mode : int number that indicate which format for the neighborhood"
45-
" mode=0 mean a matrix of neighbors(-1 for shadow neighbors)"
44+
"- mode : int number that indicate which format for the neighborhood"
45+
"mode=0 mean a matrix of neighbors(-1 for shadow neighbors)"
4646
"mode=1 means a matrix of edges of size Num_edge x 2"
47-
"return a tensor of size N1 x M where M is either max_num or the maximum number of neighbors found if mode = 0, if mode=1 return a tensor of size Num_edge x 2 and return a tensor containing the squared distance of the neighbors",
48-
"query"_a, "support"_a, "radius"_a, "max_num"_a=-1, "mode"_a=0);
47+
"return a tensor of size B x N2 x M where M is either max_num or the maximum number of neighbors found if mode = 0, if mode=1 return a tensor of size Num_edge x 2 and return a tensor containing the squared distance of the neighbors",
48+
"support"_a, "query"_a, "radius"_a, "max_num"_a=-1, "mode"_a=0);
4949
}

cpu/src/neighbors.cpp

Lines changed: 15 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ int nanoflann_neighbors(vector<scalar_t>& queries,
2323
int i0 = 0;
2424

2525
// Counting vector
26-
int max_count = 1;
26+
size_t max_count = 1;
2727

2828

2929
// Nanoflann related variables
@@ -63,36 +63,30 @@ int nanoflann_neighbors(vector<scalar_t>& queries,
6363
list_matches[i0].reserve(max_count);
6464
std::vector<std::pair<size_t, scalar_t> > ret_matches;
6565

66-
6766
const size_t nMatches = index->radiusSearch(&query_pt[0], search_radius, ret_matches, search_params);
68-
list_matches[i0] = ret_matches;
69-
if((size_t)max_count < nMatches) max_count = nMatches;
67+
if (nMatches == 0)
68+
list_matches[i0] = {std::make_pair(0,-1)};
69+
else
70+
list_matches[i0] = ret_matches;
71+
max_count = max(max_count,nMatches);
7072
i0++;
71-
72-
7373
}
7474
// Reserve the memory
7575
if(max_num > 0) {
7676
max_count = max_num;
7777
}
7878
if(mode == 0){
79-
80-
neighbors_indices.resize(list_matches.size() * max_count);
81-
dists.resize(list_matches.size() * max_count);
82-
79+
neighbors_indices.resize(list_matches.size() * max_count, 0);
80+
dists.resize(list_matches.size() * max_count, -1);
8381
i0 = 0;
84-
8582
int token = 0;
8683
for (auto& inds : list_matches){
8784
token = inds[0].first;
88-
for (int j = 0; j < max_count; j++){
89-
if ((unsigned int)j < inds.size()){
85+
for (size_t j = 0; j < max_count; j++){
86+
if (j < inds.size()){
9087
neighbors_indices[i0 * max_count + j] = inds[j].first;
9188
dists[i0 * max_count + j] = (float) inds[j].second;
92-
93-
9489
}
95-
9690
else {
9791
neighbors_indices[i0 * max_count + j] = token;
9892
dists[i0 * max_count + j] = -1;
@@ -103,9 +97,9 @@ int nanoflann_neighbors(vector<scalar_t>& queries,
10397

10498
}
10599
else if(mode == 1){
106-
int size = 0; // total number of edges
100+
size_t size = 0; // total number of edges
107101
for (auto& inds : list_matches){
108-
if((int)inds.size() <= max_count)
102+
if(inds.size() <= max_count)
109103
size += inds.size();
110104
else
111105
size += max_count;
@@ -115,8 +109,8 @@ int nanoflann_neighbors(vector<scalar_t>& queries,
115109
int i0 = 0; // index of the query points
116110
int u = 0; // current index of the neighbors_indices
117111
for (auto& inds : list_matches){
118-
for (int j = 0; j < max_count; j++){
119-
if((unsigned int)j < inds.size()){
112+
for (size_t j = 0; j < max_count; j++){
113+
if(j < inds.size()){
120114
neighbors_indices[u] = inds[j].first;
121115
neighbors_indices[u + 1] = i0;
122116
dists[u/2] = (float) inds[j].second;
@@ -125,15 +119,10 @@ int nanoflann_neighbors(vector<scalar_t>& queries,
125119
}
126120
i0++;
127121
}
128-
129-
130122
}
131123
return max_count;
132-
133-
134-
135-
136124
}
125+
137126
template<typename scalar_t>
138127
int batch_nanoflann_neighbors (vector<scalar_t>& queries,
139128
vector<scalar_t>& supports,

cpu/src/torch_nearest_neighbors.cpp

Lines changed: 24 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -7,39 +7,35 @@
77
#include <iostream>
88

99

10-
std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor query,
11-
at::Tensor support,
10+
std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor support,
11+
at::Tensor query,
1212
float radius, int max_num, int mode){
1313

1414
at::Tensor out;
1515
at::Tensor out_dists;
16-
std::vector<long> neighbors_indices;
16+
std::vector<long> neighbors_indices(query.size(0),0);
17+
std::vector<float> neighbors_dists(query.size(0), -1);
1718

1819
auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
1920
auto options_dist = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
2021
int max_count = 0;
2122

22-
std::vector<float> neighbors_dists;
23-
2423
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "radius_search", [&] {
25-
26-
27-
auto data_q = query.DATA_PTR<scalar_t>();
28-
auto data_s = support.DATA_PTR<scalar_t>();
29-
std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
30-
data_q + query.size(0)*query.size(1));
31-
std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
32-
data_s + support.size(0)*support.size(1));
33-
34-
max_count = nanoflann_neighbors<scalar_t>(queries_stl,
35-
supports_stl,
36-
neighbors_indices,
37-
neighbors_dists,
38-
radius,
39-
max_num,
40-
mode);
24+
auto data_q = query.DATA_PTR<scalar_t>();
25+
auto data_s = support.DATA_PTR<scalar_t>();
26+
std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
27+
data_q + query.size(0)*query.size(1));
28+
std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
29+
data_s + support.size(0)*support.size(1));
30+
31+
max_count = nanoflann_neighbors<scalar_t>(queries_stl,
32+
supports_stl,
33+
neighbors_indices,
34+
neighbors_dists,
35+
radius,
36+
max_num,
37+
mode);
4138
});
42-
4339
auto neighbors_dists_ptr = neighbors_dists.data();
4440
long* neighbors_indices_ptr = neighbors_indices.data();
4541
if(mode == 0){
@@ -65,10 +61,10 @@ at::Tensor degree(at::Tensor row, int64_t num_nodes) {
6561
return zero.scatter_add_(0, row, one);
6662
}
6763

68-
std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query,
69-
at::Tensor support,
70-
at::Tensor query_batch,
64+
std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support,
65+
at::Tensor query,
7166
at::Tensor support_batch,
67+
at::Tensor query_batch,
7268
float radius, int max_num, int mode) {
7369
at::Tensor idx;
7470

@@ -92,8 +88,7 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query,
9288
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_radius_search", [&] {
9389

9490
std::vector<scalar_t> queries_stl(query.DATA_PTR<scalar_t>(), query.DATA_PTR<scalar_t>() + query.numel());
95-
std::vector<scalar_t> supports_stl(support.DATA_PTR<scalar_t>(), support.DATA_PTR<scalar_t>() + support.numel());
96-
91+
std::vector<scalar_t> supports_stl(support.DATA_PTR<scalar_t>(), support.DATA_PTR<scalar_t>() + support.numel());
9792

9893
max_count = batch_nanoflann_neighbors<scalar_t>(queries_stl,
9994
supports_stl,
@@ -114,7 +109,6 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query,
114109
dist = torch::from_blob(neighbors_dists_ptr,
115110
{query.size(0), max_count},
116111
options=options_dist);
117-
118112
}
119113
else if(mode ==1){
120114
idx = torch::from_blob(neighbors_indices_ptr, {(int)neighbors_indices.size()/2, 2}, options=options);
@@ -123,12 +117,11 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query,
123117
options=options_dist);
124118
}
125119
return std::make_pair(idx.clone(), dist.clone());
126-
127120
}
128121

129122

130-
std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor query,
131-
at::Tensor support,
123+
std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor support,
124+
at::Tensor query,
132125
float radius, int max_num, int mode){
133126

134127
int b = query.size(0);

cuda/include/compat.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#ifdef VERSION_GE_1_3
2+
#define DATA_PTR data_ptr
3+
#else
4+
#define DATA_PTR data
5+
#endif

cuda/include/utils.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,22 +4,22 @@
44

55
#define CHECK_CUDA(x) \
66
do { \
7-
AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \
7+
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \
88
} while (0)
99

1010
#define CHECK_CONTIGUOUS(x) \
1111
do { \
12-
AT_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \
12+
TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \
1313
} while (0)
1414

1515
#define CHECK_IS_INT(x) \
1616
do { \
17-
AT_CHECK(x.scalar_type() == at::ScalarType::Int, \
17+
TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
1818
#x " must be an int tensor"); \
1919
} while (0)
2020

2121
#define CHECK_IS_FLOAT(x) \
2222
do { \
23-
AT_CHECK(x.scalar_type() == at::ScalarType::Float, \
23+
TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \
2424
#x " must be a float tensor"); \
2525
} while (0)

cuda/src/ball_query.cpp

Lines changed: 15 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,15 @@
11
#include "ball_query.h"
22
#include "utils.h"
3+
#include "compat.h"
34

45
void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius,
56
int nsample, const float *new_xyz,
67
const float *xyz, int *idx);
78

89
void query_ball_point_kernel_partial_wrapper(long batch_size,
910
int size_x,
10-
int size_y,
11-
float radius,
11+
int size_y,
12+
float radius,
1213
int nsample,
1314
const float *x,
1415
const float *y,
@@ -33,10 +34,10 @@ at::Tensor ball_query_dense(at::Tensor new_xyz, at::Tensor xyz, const float radi
3334

3435
if (new_xyz.type().is_cuda()) {
3536
query_ball_point_kernel_dense_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
36-
radius, nsample, new_xyz.data<float>(),
37-
xyz.data<float>(), idx.data<int>());
37+
radius, nsample, new_xyz.DATA_PTR<float>(),
38+
xyz.DATA_PTR<float>(), idx.DATA_PTR<int>());
3839
} else {
39-
AT_CHECK(false, "CPU not supported");
40+
TORCH_CHECK(false, "CPU not supported");
4041
}
4142

4243
return idx;
@@ -68,13 +69,13 @@ std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x,
6869

6970
at::Tensor idx = torch::full({y.size(0), nsample}, x.size(0),
7071
at::device(y.device()).dtype(at::ScalarType::Long));
71-
72+
7273
at::Tensor dist = torch::full({y.size(0), nsample}, -1,
7374
at::device(y.device()).dtype(at::ScalarType::Float));
7475

7576
cudaSetDevice(x.get_device());
7677
auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
77-
cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
78+
cudaMemcpy(batch_sizes, batch_x[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
7879
cudaMemcpyDeviceToHost);
7980
auto batch_size = batch_sizes[0] + 1;
8081

@@ -88,14 +89,14 @@ std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x,
8889
x.size(0),
8990
y.size(0),
9091
radius, nsample,
91-
x.data<float>(),
92-
y.data<float>(),
93-
batch_x.data<long>(),
94-
batch_y.data<long>(),
95-
idx.data<long>(),
96-
dist.data<float>());
92+
x.DATA_PTR<float>(),
93+
y.DATA_PTR<float>(),
94+
batch_x.DATA_PTR<long>(),
95+
batch_y.DATA_PTR<long>(),
96+
idx.DATA_PTR<long>(),
97+
dist.DATA_PTR<float>());
9798
} else {
98-
AT_CHECK(false, "CPU not supported");
99+
TORCH_CHECK(false, "CPU not supported");
99100
}
100101

101102
return std::make_pair(idx, dist);

cuda/src/group_points.cpp

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
#include "group_points.h"
22
#include "utils.h"
3+
#include "compat.h"
34

45
void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
56
const float *points, const int *idx,
@@ -25,10 +26,10 @@ at::Tensor group_points(at::Tensor points, at::Tensor idx) {
2526

2627
if (points.type().is_cuda()) {
2728
group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
28-
idx.size(1), idx.size(2), points.data<float>(),
29-
idx.data<int>(), output.data<float>());
29+
idx.size(1), idx.size(2), points.DATA_PTR<float>(),
30+
idx.DATA_PTR<int>(), output.DATA_PTR<float>());
3031
} else {
31-
AT_CHECK(false, "CPU not supported");
32+
TORCH_CHECK(false, "CPU not supported");
3233
}
3334

3435
return output;
@@ -51,9 +52,9 @@ at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) {
5152
if (grad_out.type().is_cuda()) {
5253
group_points_grad_kernel_wrapper(
5354
grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2),
54-
grad_out.data<float>(), idx.data<int>(), output.data<float>());
55+
grad_out.DATA_PTR<float>(), idx.DATA_PTR<int>(), output.DATA_PTR<float>());
5556
} else {
56-
AT_CHECK(false, "CPU not supported");
57+
TORCH_CHECK(false, "CPU not supported");
5758
}
5859

5960
return output;

0 commit comments

Comments
 (0)