Skip to content

Commit 2884bcc

Browse files
Merge pull request #11 from humanpose1/query_ball_cpu
Query ball cpu
2 parents fdc285a + 153ea08 commit 2884bcc

File tree

8 files changed

+128
-33
lines changed

8 files changed

+128
-33
lines changed

cpu/include/ball_query.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,7 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query,
99
at::Tensor query_batch,
1010
at::Tensor support_batch,
1111
float radius, int max_num, int mode);
12+
13+
// Radius (ball) search over a dense batch of point clouds.
//   query   : tensor of size B x N1 x 3, the points whose neighbours are searched
//   support : tensor of size B x N2 x 3, the points the search tree is built on
//   radius  : size of the ball used for the radius search
//   max_num : maximum number of neighbours kept per query point
//             (-1 means all the neighbours found are kept)
//   mode    : output format; 0 = matrix of neighbours, 1 = matrix of edges
// Returns a pair (neighbour indices, squared distances of the neighbours).
std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor query,
14+
at::Tensor support,
15+
float radius, int max_num, int mode);

cpu/include/cloud.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ struct PointCloud
5757
pts = temp;
5858
}
5959
void set_batch(std::vector<scalar_t> new_pts, int begin, int size){
60+
6061
std::vector<PointXYZ> temp(size);
6162
for(int i=0; i < size; i++){
6263
PointXYZ point;

cpu/src/bindings.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
3535
"mode=1 means a matrix of edges of size Num_edge x 2"
3636
"return a tensor of size N1 x M where M is either max_num or the maximum number of neighbors found if mode = 0, if mode=1 return a tensor of size Num_edge x 2 and return a tensor containing the squared distance of the neighbors",
3737
"query"_a, "support"_a, "query_batch"_a, "support_batch"_a, "radius"_a, "max_num"_a=-1, "mode"_a=0);
38+
m.def("dense_ball_query", &dense_ball_query,
39+
"compute the radius search of a batch of point cloud using nanoflann"
40+
"-query : a pytorch tensor of size B x N1 x 3,. used to query the nearest neighbors"
41+
"- support : a pytorch tensor of size B x N2 x 3. used to build the tree"
42+
"- radius : float number, size of the ball for the radius search."
43+
"- max_num : int number, indicate the maximum of neaghbors allowed(if -1 then all the possible neighbors will be computed). "
44+
" - mode : int number that indicate which format for the neighborhood"
45+
" mode=0 mean a matrix of neighbors(-1 for shadow neighbors)"
46+
"mode=1 means a matrix of edges of size Num_edge x 2"
47+
"return a tensor of size N1 x M where M is either max_num or the maximum number of neighbors found if mode = 0, if mode=1 return a tensor of size Num_edge x 2 and return a tensor containing the squared distance of the neighbors",
48+
"query"_a, "support"_a, "radius"_a, "max_num"_a=-1, "mode"_a=0);
3849
}

cpu/src/neighbors.cpp

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,20 @@ int nanoflann_neighbors(vector<scalar_t>& queries,
8282

8383
i0 = 0;
8484

85+
int token = 0;
8586
for (auto& inds : list_matches){
87+
token = inds[0].first;
8688
for (int j = 0; j < max_count; j++){
8789
if (j < inds.size()){
8890
neighbors_indices[i0 * max_count + j] = inds[j].first;
8991
dists[i0 * max_count + j] = (float) inds[j].second;
92+
93+
9094
}
9195

9296
else {
93-
neighbors_indices[i0 * max_count + j] = -1;
94-
dists[i0 * max_count + j] = radius * radius;
97+
neighbors_indices[i0 * max_count + j] = token;
98+
dists[i0 * max_count + j] = -1;
9599
}
96100
}
97101
i0++;
@@ -186,24 +190,30 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
186190
search_params.sorted = true;
187191
for (auto& p0 : query_pcd.pts){
188192
// Check if we changed batch
189-
190-
if (i0 == sum_qb + q_batches[b]){
193+
if (i0 == sum_qb + q_batches[b] && b < s_batches.size()){
191194
sum_qb += q_batches[b];
192195
sum_sb += s_batches[b];
196+
193197
b++;
194198

195199
// Change the points
196200
current_cloud.pts.clear();
197201
current_cloud.set_batch(supports, sum_sb, s_batches[b]);
198202
// Build KDTree of the current element of the batch
199203
delete index;
204+
200205
index = new my_kd_tree_t(3, current_cloud, tree_params);
201206
index->buildIndex();
202207
}
203208
// Initial guess of neighbors size
209+
210+
204211
all_inds_dists[i0].reserve(max_count);
205212
// Find neighbors
213+
//std::cerr << p0.x << p0.y << p0.z<<std::endl;
206214
scalar_t query_pt[3] = { p0.x, p0.y, p0.z};
215+
216+
207217
size_t nMatches = index->radiusSearch(query_pt, r2, all_inds_dists[i0], search_params);
208218
// Update max count
209219

@@ -217,8 +227,10 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
217227
max_count = max_num;
218228
}
219229
// Reserve the memory
230+
220231
if(mode == 0){
221232
neighbors_indices.resize(query_pcd.pts.size() * max_count);
233+
222234
dists.resize(query_pcd.pts.size() * max_count);
223235
i0 = 0;
224236
sum_sb = 0;
@@ -227,6 +239,7 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
227239

228240
for (auto& inds_dists : all_inds_dists){// Check if we changed batch
229241

242+
230243
if (i0 == sum_qb + q_batches[b]){
231244
sum_qb += q_batches[b];
232245
sum_sb += s_batches[b];
@@ -239,8 +252,8 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
239252
dists[i0 * max_count + j] = (float) inds_dists[j].second;
240253
}
241254
else {
242-
neighbors_indices[i0 * max_count + j] = supports.size();
243-
dists[i0 * max_count + j] = radius * radius;
255+
neighbors_indices[i0 * max_count + j] = supports.size()/3;
256+
dists[i0 * max_count + j] = -1;
244257
}
245258

246259
}

cpu/src/torch_nearest_neighbors.cpp

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,22 @@ std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor query,
6161

6262
/**
 * Convert a sorted per-point batch-index vector into per-batch point counts.
 *
 * Despite its name this function does not compute a running sum. The output
 * is laid out as
 *   res[0]     = 0
 *   res[k + 1] = number of entries of `batch` equal to batch[0] + k
 * and therefore holds (batch.back() - batch[0] + 2) elements. Batch ids inside
 * that range that never occur get a count of 0.
 *
 * @param batch batch id of every point. Must be sorted in non-decreasing
 *              order (equal ids contiguous); unsorted input is not supported.
 * @param res   output vector; its previous content is discarded.
 */
void cumsum(const std::vector<long>& batch, std::vector<long>& res){
    // Empty input: nothing to count. Keep only the leading zero instead of
    // reading batch[0] / batch[batch.size()-1] out of bounds.
    if(batch.empty()){
        res.assign(1, 0);
        return;
    }

    const long first = batch.front();
    // assign (not resize) so that a reused output vector cannot keep stale
    // counts in the slots of batch ids that do not occur.
    const auto slots = static_cast<std::vector<long>::size_type>(batch.back() - first + 2);
    res.assign(slots, 0);

    long current = first;  // batch id of the run being counted
    long run = 0;          // number of points seen so far in that run
    for(const long id : batch){
        if(id != current){
            // The run of `current` just ended: store its length.
            res[current - first + 1] = run;
            current = id;
            run = 0;
        }
        ++run;
    }
    // Store the length of the last run.
    res[current - first + 1] = run;
}
7981

8082
std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query,
@@ -89,9 +91,11 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query,
8991
std::vector<long> query_batch_stl = std::vector<long>(data_qb, data_qb+query_batch.size(0));
9092
std::vector<long> cumsum_query_batch_stl;
9193
cumsum(query_batch_stl, cumsum_query_batch_stl);
94+
9295
std::vector<long> support_batch_stl = std::vector<long>(data_sb, data_sb+support_batch.size(0));
9396
std::vector<long> cumsum_support_batch_stl;
9497
cumsum(support_batch_stl, cumsum_support_batch_stl);
98+
9599
std::vector<long> neighbors_indices;
96100

97101
auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
@@ -107,6 +111,7 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query,
107111
std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
108112
data_s + support.size(0)*support.size(1));
109113

114+
110115
max_count = batch_nanoflann_neighbors<scalar_t>(queries_stl,
111116
supports_stl,
112117
cumsum_query_batch_stl,
@@ -117,6 +122,7 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query,
117122
max_num,
118123
mode);
119124
});
125+
120126
long* neighbors_indices_ptr = neighbors_indices.data();
121127
auto neighbors_dists_ptr = neighbors_dists.data();
122128

@@ -135,3 +141,23 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query,
135141
}
136142
return std::make_pair(out.clone(), out_dists.clone());
137143
}
144+
145+
146+
// Dense (batched) radius search.
//
// Runs ball_query independently on every slice along dimension 0 of `query`
// and `support`, then stacks the per-slice index tensors and the per-slice
// distance tensors into two batched tensors.
//
//   query, support : indexed along dimension 0; query.size(0) gives the
//                    number of slices processed
//   radius, max_num, mode : forwarded unchanged to ball_query
//
// Returns the pair (stacked indices, stacked distances).
//
// NOTE(review): torch::stack needs every slice result to have the same shape
// and at least one slice; presumably callers pass a positive max_num and a
// non-empty batch -- confirm.
std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor query,
                                                   at::Tensor support,
                                                   float radius, int max_num, int mode){
    int batch_size = query.size(0);
    vector<at::Tensor> indices_per_slice;
    vector<at::Tensor> dists_per_slice;

    int slice = 0;
    while (slice < batch_size){
        auto slice_result = ball_query(query[slice], support[slice], radius, max_num, mode);
        indices_per_slice.push_back(slice_result.first);
        dists_per_slice.push_back(slice_result.second);
        ++slice;
    }

    auto stacked_indices = torch::stack(indices_per_slice);
    auto stacked_dists = torch::stack(dists_per_slice);
    return std::make_pair(stacked_indices, stacked_dists);
}

test/test_ballquerry.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@ def test_simple_cpu(self):
1919

2020
def test_cpu_gpu_equality(self):
2121
a = torch.randn(5, 1000, 3)
22-
npt.assert_array_equal(ball_query_dense(0.1, 17, a, a).detach().numpy(),
23-
ball_query_dense(0.1, 17, a.cuda(), a.cuda()).cpu().detach().numpy())
22+
res_cpu = ball_query_dense(0.1, 17, a, a).detach().numpy()
23+
res_cuda = ball_query_dense(0.1, 17, a.cuda(), a.cuda()).cpu().detach().numpy()
24+
for i in range(a.shape[0]):
25+
for j in range(a.shape[1]):
26+
# Because it is not necessary the same order
27+
assert set(res_cpu[i][j]) == set(res_cuda[i][j])
2428

2529

2630
if __name__ == "__main__":

test/test_ballquerry_partial.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
import torch
33
from torch_points import ball_query
4+
from torch_points.points_cpu import ball_query as cpu_ball_query
45
from torch_cluster import radius_cuda
56
import numpy.testing as npt
67
import numpy as np
@@ -11,7 +12,7 @@ def test_simple_gpu(self):
1112
y = torch.tensor([[0, 0, 0]]).to(torch.float).cuda()
1213
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
1314
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()
14-
15+
1516
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
1617
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()
1718

@@ -26,5 +27,43 @@ def test_simple_gpu(self):
2627
npt.assert_array_almost_equal(idx, idx_answer)
2728
npt.assert_array_almost_equal(dist2, dist2_answer)
2829

30+
31+
def test_simple_cpu(self):
    """Check the CPU PARTIAL_DENSE ball query on a tiny hand-made example.

    `x` holds four points spread over batches 0 and 1, `y` holds a single
    point at the origin in batch 0. With radius 1 and two slots per point,
    only x[1] = (0.1, 0, 0) is expected to find a neighbour (index 0,
    squared distance 0.01); every other slot is expected to hold the
    padding values (index 1, distance -1).
    """
    x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]]).to(torch.float)
    y = torch.tensor([[0, 0, 0]]).to(torch.float)
    # The batch vectors are built once; the original repeated these two
    # assignments with identical values.
    batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long()
    batch_y = torch.from_numpy(np.asarray([0])).long()

    idx, dist2 = ball_query(1., 2, x, y, batch_x, batch_y, mode="PARTIAL_DENSE")

    idx = idx.detach().cpu().numpy()
    dist2 = dist2.detach().cpu().numpy()

    idx_answer = np.asarray([[1, 1], [0, 1], [1, 1], [1, 1]])
    dist2_answer = np.asarray([[-1, -1], [0.01, -1], [-1, -1], [-1, -1]]).astype(np.float32)

    npt.assert_array_almost_equal(idx, idx_answer)
    npt.assert_array_almost_equal(dist2, dist2_answer)
50+
51+
def test_random_cpu(self):
    """Smoke-test the CPU PARTIAL_DENSE ball query on random point clouds.

    The test makes no assertion: it only checks that the calls run, and
    prints one row of the batched result next to the same row of the plain
    CPU ball query for visual comparison.
    """
    a = torch.randn(1000, 3).to(torch.float)
    b = torch.randn(1500, 3).to(torch.float)
    # randint with an exclusive upper bound of 1 only yields zeros, so every
    # point of both clouds belongs to batch 0.
    batch_a = torch.randint(1, (1000,)).sort(0)[0].long()
    batch_b = torch.randint(1, (1500,)).sort(0)[0].long()

    idx, dist2 = ball_query(1.0, 12, a, b, batch_a, batch_b, mode="PARTIAL_DENSE")
    # Same query with the two clouds swapped: only checks that it runs. Its
    # result is no longer stored over `idx`, so the rows printed below come
    # from the same (a, b) argument order.
    ball_query(1.0, 12, b, a, batch_b, batch_a, mode="PARTIAL_DENSE")

    idx = idx.detach().cpu().numpy()
    dist2 = dist2.detach().cpu().numpy()
    # NOTE(review): presumably cpu_ball_query(a, b, ...) uses the same
    # query/support roles as ball_query(..., a, b, ...) above -- confirm.
    idx2, _ = cpu_ball_query(a, b, 1.0, 12)
    # The original nested print() inside print(), which printed idx2[5] on
    # its own line and then "None" next to idx[5].
    print(idx[5], idx2[5])
62+
63+
64+
65+
66+
67+
2968
if __name__ == "__main__":
3069
unittest.main()

torch_points/torchpoints.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -251,17 +251,10 @@ def forward(ctx, radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=No
251251
if new_xyz.is_cuda:
252252
return tpcuda.ball_query_dense(new_xyz, xyz, radius, nsample)
253253
else:
254-
b = xyz.size(0)
255-
npoints = new_xyz.size(1)
256-
n = xyz.size(1)
257-
batch_new_xyz = torch.arange(0, b, dtype=torch.long).repeat(npoints, 1).T.reshape(-1)
258-
batch_xyz = torch.arange(0, b, dtype=torch.long).repeat(n, 1).T.reshape(-1)
259-
ind, dist = tpcpu.batch_ball_query(new_xyz.view(-1, 3),
260-
xyz.view(-1, 3),
261-
batch_new_xyz,
262-
batch_xyz,
263-
radius, nsample)
264-
return ind.view(b, npoints, nsample)
254+
ind, dist = tpcpu.dense_ball_query(new_xyz,
255+
xyz,
256+
radius, nsample, mode=0)
257+
return ind
265258

266259
@staticmethod
267260
def backward(ctx, a=None):
@@ -299,7 +292,11 @@ def forward(ctx, radius, nsample, x, y, batch_x, batch_y):
299292
batch_y,
300293
radius, nsample)
301294
else:
302-
raise NotImplementedError
295+
ind, dist = tpcpu.batch_ball_query(x, y,
296+
batch_x,
297+
batch_y,
298+
radius, nsample, mode=0)
299+
return ind, dist
303300

304301
@staticmethod
305302
def backward(ctx, a=None):
@@ -319,9 +316,9 @@ def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y):
319316
y : torch.Tensor
320317
(N, npoint, 3) centers of the ball query
321318
batch_x : torch.Tensor
322-
(M, ) Contains indexes to indicate within batch it belongs to.
319+
(M, ) Contains indexes to indicate within batch it belongs to.
323320
batch_y : torch.Tensor
324-
(N, ) Contains indexes to indicate within batch it belongs to
321+
(N, ) Contains indexes to indicate within batch it belongs to
325322
326323
Returns
327324
-------

0 commit comments

Comments
 (0)