Skip to content

Commit b1cde16

Browse files
Merge pull request #14 from humanpose1/debug_cpu
debug of cpu ball query
2 parents 031a225 + 8ca6838 commit b1cde16

File tree

3 files changed

+81
-90
lines changed

3 files changed

+81
-90
lines changed

cpu/src/neighbors.cpp

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ int nanoflann_neighbors(vector<scalar_t>& queries,
8686
for (auto& inds : list_matches){
8787
token = inds[0].first;
8888
for (int j = 0; j < max_count; j++){
89-
if (j < inds.size()){
89+
if ((unsigned int)j < inds.size()){
9090
neighbors_indices[i0 * max_count + j] = inds[j].first;
9191
dists[i0 * max_count + j] = (float) inds[j].second;
9292

@@ -116,7 +116,7 @@ int nanoflann_neighbors(vector<scalar_t>& queries,
116116
int u = 0; // curent index of the neighbors_indices
117117
for (auto& inds : list_matches){
118118
for (int j = 0; j < max_count; j++){
119-
if(j < inds.size()){
119+
if((unsigned int)j < inds.size()){
120120
neighbors_indices[u] = inds[j].first;
121121
neighbors_indices[u + 1] = i0;
122122
dists[u/2] = (float) inds[j].second;
@@ -158,9 +158,8 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
158158

159159

160160
// batch index
161-
long b = 0;
162-
long sum_qb = 0;
163-
long sum_sb = 0;
161+
int b = 0;
162+
164163

165164
// Nanoflann related variables
166165
// ***************************
@@ -180,7 +179,7 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
180179
// Pointer to trees
181180
my_kd_tree_t* index;
182181
// Build KDTree for the first batch element
183-
current_cloud.set_batch(supports, sum_sb, s_batches[b]);
182+
current_cloud.set_batch(supports, s_batches[b], s_batches[b+1]);
184183
index = new my_kd_tree_t(3, current_cloud, tree_params);
185184
index->buildIndex();
186185
// Search neigbors indices
@@ -190,21 +189,22 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
190189
search_params.sorted = true;
191190
for (auto& p0 : query_pcd.pts){
192191
// Check if we changed batch
193-
if (i0 == sum_qb + q_batches[b] && b < s_batches.size()){
194-
sum_qb += q_batches[b];
195-
sum_sb += s_batches[b];
196192

197-
b++;
193+
if (i0 == q_batches[b+1] && b < (int)s_batches.size()-1 && b < (int)q_batches.size()-1){
198194

199195
// Change the points
196+
b++;
200197
current_cloud.pts.clear();
201-
current_cloud.set_batch(supports, sum_sb, s_batches[b]);
198+
if(s_batches[b] < s_batches[b+1])
199+
current_cloud.set_batch(supports, s_batches[b], s_batches[b+1]);
202200
// Build KDTree of the current element of the batch
203201
delete index;
204202

205203
index = new my_kd_tree_t(3, current_cloud, tree_params);
206204
index->buildIndex();
205+
207206
}
207+
208208
// Initial guess of neighbors size
209209

210210

@@ -233,22 +233,19 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
233233

234234
dists.resize(query_pcd.pts.size() * max_count);
235235
i0 = 0;
236-
sum_sb = 0;
237-
sum_qb = 0;
236+
238237
b = 0;
239238

240239
for (auto& inds_dists : all_inds_dists){// Check if we changed batch
241240

242241

243-
if (i0 == sum_qb + q_batches[b]){
244-
sum_qb += q_batches[b];
245-
sum_sb += s_batches[b];
242+
if (i0 == q_batches[b+1] && b < (int)s_batches.size()-1 && b < (int)q_batches.size()-1){
246243
b++;
247244
}
248245

249246
for (int j = 0; j < max_count; j++){
250-
if (j < inds_dists.size()){
251-
neighbors_indices[i0 * max_count + j] = inds_dists[j].first + sum_sb;
247+
if ((unsigned int)j < inds_dists.size()){
248+
neighbors_indices[i0 * max_count + j] = inds_dists[j].first + s_batches[b];
252249
dists[i0 * max_count + j] = (float) inds_dists[j].second;
253250
}
254251
else {
@@ -273,19 +270,15 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
273270
neighbors_indices.resize(size * 2);
274271
dists.resize(size);
275272
i0 = 0;
276-
sum_sb = 0;
277-
sum_qb = 0;
278273
b = 0;
279274
int u = 0;
280275
for (auto& inds_dists : all_inds_dists){
281-
if (i0 == sum_qb + q_batches[b]){
282-
sum_qb += q_batches[b];
283-
sum_sb += s_batches[b];
276+
if (i0 == q_batches[b+1] && b < (int)s_batches.size()-1 && b < (int)q_batches.size()-1){
284277
b++;
285278
}
286279
for (int j = 0; j < max_count; j++){
287-
if (j < inds_dists.size()){
288-
neighbors_indices[u] = inds_dists[j].first + sum_sb;
280+
if ((unsigned int)j < inds_dists.size()){
281+
neighbors_indices[u] = inds_dists[j].first + s_batches[b];
289282
neighbors_indices[u + 1] = i0;
290283
dists[u/2] = (float) inds_dists[j].second;
291284
u += 2;

cpu/src/torch_nearest_neighbors.cpp

Lines changed: 37 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -59,87 +59,71 @@ std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor query,
5959
return std::make_pair(out.clone(), out_dists.clone());
6060
}
6161

62-
void cumsum(const vector<long>& batch, vector<long>& res){
63-
64-
res.resize(batch[batch.size()-1]-batch[0]+2, 0);
65-
long ind = batch[0];
66-
long incr = 1;
67-
if(res.size() > 1){
68-
for(int i=1; i < batch.size(); i++){
69-
if(batch[i] == ind)
70-
incr++;
71-
else{
72-
res[ind-batch[0]+1] = incr;
73-
incr =1;
74-
ind = batch[i];
75-
}
76-
}
77-
78-
}
79-
res[ind-batch[0]+1] = incr;
62+
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
63+
auto zero = at::zeros(num_nodes, row.options());
64+
auto one = at::ones(row.size(0), row.options());
65+
return zero.scatter_add_(0, row, one);
8066
}
8167

8268
std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query,
8369
at::Tensor support,
8470
at::Tensor query_batch,
8571
at::Tensor support_batch,
8672
float radius, int max_num, int mode) {
87-
at::Tensor out;
88-
at::Tensor out_dists;
89-
auto data_qb = query_batch.DATA_PTR<long>();
90-
auto data_sb = support_batch.DATA_PTR<long>();
91-
std::vector<long> query_batch_stl = std::vector<long>(data_qb, data_qb+query_batch.size(0));
92-
std::vector<long> cumsum_query_batch_stl;
93-
cumsum(query_batch_stl, cumsum_query_batch_stl);
94-
95-
std::vector<long> support_batch_stl = std::vector<long>(data_sb, data_sb+support_batch.size(0));
96-
std::vector<long> cumsum_support_batch_stl;
97-
cumsum(support_batch_stl, cumsum_support_batch_stl);
73+
at::Tensor idx;
9874

75+
at::Tensor dist;
9976
std::vector<long> neighbors_indices;
77+
std::vector<float> neighbors_dists;
10078

10179
auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
10280
auto options_dist = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
81+
10382
int max_count = 0;
104-
std::vector<float> neighbors_dists;
83+
auto batch_access = query_batch.accessor<int64_t, 1>();
84+
auto batch_size = batch_access[-1] + 1;
85+
query_batch = degree(query_batch, batch_size);
86+
query_batch = at::cat({at::zeros(1, query_batch.options()), query_batch.cumsum(0)}, 0);
87+
support_batch = degree(support_batch, batch_size);
88+
support_batch = at::cat({at::zeros(1, support_batch.options()), support_batch.cumsum(0)}, 0);
89+
std::vector<long> query_batch_stl(query_batch.DATA_PTR<long>(), query_batch.DATA_PTR<long>() + query_batch.numel());
90+
std::vector<long> support_batch_stl(support_batch.DATA_PTR<long>(), support_batch.DATA_PTR<long>() + support_batch.numel());
91+
10592
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_radius_search", [&] {
10693

107-
auto data_q = query.DATA_PTR<scalar_t>();
108-
auto data_s = support.DATA_PTR<scalar_t>();
109-
std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
110-
data_q + query.size(0)*query.size(1));
111-
std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
112-
data_s + support.size(0)*support.size(1));
94+
std::vector<scalar_t> queries_stl(query.DATA_PTR<scalar_t>(), query.DATA_PTR<scalar_t>() + query.numel());
95+
std::vector<scalar_t> supports_stl(support.DATA_PTR<scalar_t>(), support.DATA_PTR<scalar_t>() + support.numel());
11396

11497

115-
max_count = batch_nanoflann_neighbors<scalar_t>(queries_stl,
98+
max_count = batch_nanoflann_neighbors<scalar_t>(queries_stl,
11699
supports_stl,
117-
cumsum_query_batch_stl,
118-
cumsum_support_batch_stl,
100+
query_batch_stl,
101+
support_batch_stl,
119102
neighbors_indices,
120103
neighbors_dists,
121104
radius,
122105
max_num,
123106
mode);
124-
});
125-
126-
long* neighbors_indices_ptr = neighbors_indices.data();
127-
auto neighbors_dists_ptr = neighbors_dists.data();
128107

129108

109+
});
110+
auto neighbors_dists_ptr = neighbors_dists.data();
111+
long* neighbors_indices_ptr = neighbors_indices.data();
130112
if(mode == 0){
131-
out = torch::from_blob(neighbors_indices_ptr, {query.size(0), max_count}, options=options);
132-
out_dists = torch::from_blob(neighbors_dists_ptr,
133-
{query.size(0), max_count},
134-
options=options_dist);
113+
idx = torch::from_blob(neighbors_indices_ptr, {query.size(0), max_count}, options=options);
114+
dist = torch::from_blob(neighbors_dists_ptr,
115+
{query.size(0), max_count},
116+
options=options_dist);
117+
135118
}
136-
else if(mode == 1){
137-
out = torch::from_blob(neighbors_indices_ptr, {(int)neighbors_indices.size()/2, 2}, options=options);
138-
out_dists = torch::from_blob(neighbors_dists_ptr,
139-
{(int)neighbors_indices.size()/2, 1},
140-
options=options_dist);
119+
else if(mode ==1){
120+
idx = torch::from_blob(neighbors_indices_ptr, {(int)neighbors_indices.size()/2, 2}, options=options);
121+
dist = torch::from_blob(neighbors_dists_ptr,
122+
{(int)neighbors_indices.size()/2, 1},
123+
options=options_dist);
141124
}
142-
return std::make_pair(out.clone(), out_dists.clone());
125+
return std::make_pair(idx.clone(), dist.clone());
126+
143127
}
144128

145129

test/test_ballquerry.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from torch_points import ball_query
44
import numpy.testing as npt
55
import numpy as np
6+
from sklearn.neighbors import KDTree
67

78
from . import run_if_cuda
89

@@ -54,23 +55,23 @@ def test_simple_gpu(self):
5455
npt.assert_array_almost_equal(idx, idx_answer)
5556
npt.assert_array_almost_equal(dist2, dist2_answer)
5657

57-
# def test_simple_cpu(self):
58-
# x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]]).to(torch.float)
59-
# y = torch.tensor([[0, 0, 0]]).to(torch.float)
58+
def test_simple_cpu(self):
59+
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]]).to(torch.float)
60+
y = torch.tensor([[0, 0, 0]]).to(torch.float)
6061

61-
# batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long()
62-
# batch_y = torch.from_numpy(np.asarray([0])).long()
62+
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long()
63+
batch_y = torch.from_numpy(np.asarray([0])).long()
6364

64-
# idx, dist2 = ball_query(1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y)
65+
idx, dist2 = ball_query(1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y)
6566

66-
# idx = idx.detach().cpu().numpy()
67-
# dist2 = dist2.detach().cpu().numpy()
67+
idx = idx.detach().cpu().numpy()
68+
dist2 = dist2.detach().cpu().numpy()
6869

69-
# idx_answer = np.asarray([[1, 1], [0, 1], [1, 1], [1, 1]])
70-
# dist2_answer = np.asarray([[-1, -1], [0.01, -1], [-1, -1], [-1, -1]]).astype(np.float32)
70+
idx_answer = np.asarray([[1, 1], [0, 1], [1, 1], [1, 1]])
71+
dist2_answer = np.asarray([[-1, -1], [0.01, -1], [-1, -1], [-1, -1]]).astype(np.float32)
7172

72-
# npt.assert_array_almost_equal(idx, idx_answer)
73-
# npt.assert_array_almost_equal(dist2, dist2_answer)
73+
npt.assert_array_almost_equal(idx, idx_answer)
74+
npt.assert_array_almost_equal(dist2, dist2_answer)
7475

7576
def test_random_cpu(self):
7677
a = torch.randn(1000, 3).to(torch.float)
@@ -80,6 +81,19 @@ def test_random_cpu(self):
8081
idx, dist = ball_query(1.0, 12, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b)
8182
idx2, dist2 = ball_query(1.0, 12, b, a, mode="PARTIAL_DENSE", batch_x=batch_b, batch_y=batch_a)
8283

84+
zeros = torch.zeros_like(batch_b)
85+
idx3, dist3 = ball_query(0.5, 17, b, b, mode="PARTIAL_DENSE", batch_x=zeros, batch_y=zeros)
86+
87+
88+
# Comparison to see if we have the same result
89+
tree = KDTree(b.detach().numpy())
90+
idx3_sk = tree.query_radius(b.detach().numpy(), r=0.5)
91+
i = np.random.randint(len(batch_b))
92+
for p in idx3[i].detach().numpy():
93+
if p < len(batch_b):
94+
assert p in idx3_sk[i]
95+
96+
8397

8498
if __name__ == "__main__":
8599
unittest.main()

0 commit comments

Comments
 (0)