Skip to content

Commit 68bf9b9

Browse files
Radius search (#27)
* Optimise radius search * remove debug logs
1 parent 560745d commit 68bf9b9

File tree

6 files changed

+113
-60
lines changed

6 files changed

+113
-60
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Unreleased
2+
3+
- ball query returns squared distance instead of distance
4+
- leaner Point Cloud struct that avoids copying data

cpu/include/cloud.h

Lines changed: 41 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -31,48 +31,29 @@
3131

3232
template <typename scalar_t> struct PointCloud
3333
{
34-
struct PointXYZ
35-
{
36-
scalar_t x, y, z;
37-
};
38-
39-
std::vector<PointXYZ> pts;
40-
4134
void set(const std::vector<scalar_t>& new_pts)
4235
{
43-
pts.clear();
44-
pts.resize(new_pts.size() / 3);
45-
for (unsigned int i = 0; i < new_pts.size(); i++)
46-
{
47-
if (i % 3 == 0)
48-
{
49-
PointXYZ point;
50-
point.x = new_pts[i];
51-
point.y = new_pts[i + 1];
52-
point.z = new_pts[i + 2];
53-
pts[i / 3] = point;
54-
}
55-
}
36+
pts = new_pts.data();
37+
length = new_pts.size() / 3;
5638
}
5739
void set_batch(const std::vector<scalar_t>& new_pts, int begin, int end)
5840
{
59-
int size = end - begin;
60-
pts.clear();
61-
pts.resize(size);
62-
for (int i = 0; i < size; i++)
63-
{
64-
PointXYZ point;
65-
point.x = new_pts[3 * (begin + i)];
66-
point.y = new_pts[3 * (begin + i) + 1];
67-
point.z = new_pts[3 * (begin + i) + 2];
68-
pts[i] = point;
69-
}
41+
pts = new_pts.data();
42+
int start = begin * 3;
43+
pts += start;
44+
length = (end - begin);
7045
}
7146

7247
// Must return the number of data points
7348
inline size_t kdtree_get_point_count() const
7449
{
75-
return pts.size();
50+
return get_point_count();
51+
}
52+
53+
// Must return the number of data points
54+
inline size_t get_point_count() const
55+
{
56+
return length;
7657
}
7758

7859
// Returns the dim'th component of the idx'th point in the class:
@@ -82,11 +63,11 @@ template <typename scalar_t> struct PointCloud
8263
inline scalar_t kdtree_get_pt(const size_t idx, const size_t dim) const
8364
{
8465
if (dim == 0)
85-
return pts[idx].x;
66+
return pts[idx * 3];
8667
else if (dim == 1)
87-
return pts[idx].y;
68+
return pts[idx * 3 + 1];
8869
else
89-
return pts[idx].z;
70+
return pts[idx * 3 + 2];
9071
}
9172

9273
// Optional bounding-box computation: return false to default to a standard
@@ -98,4 +79,29 @@ template <typename scalar_t> struct PointCloud
9879
{
9980
return false;
10081
}
82+
83+
const scalar_t* get_point_ptr(const int i) const
84+
{
85+
return pts + i * 3;
86+
}
87+
88+
std::array<scalar_t, 3> operator[](const size_t index) const
89+
{
90+
return {pts[index * 3], pts[index * 3 + 1], pts[index * 3 + 2]};
91+
}
92+
93+
private:
94+
const scalar_t* pts;
95+
size_t length;
10196
};
97+
98+
template <typename scalar_t>
99+
inline std::ostream& operator<<(std::ostream& os, const PointCloud<scalar_t>& P)
100+
{
101+
for (size_t i = 0; i < P.get_point_count(); i++)
102+
{
103+
auto p = P[i];
104+
os << "[" << p[0] << ", " << p[1] << ", " << p[2] << "];";
105+
}
106+
return os;
107+
}

cpu/src/neighbors.cpp

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Taken from https://github.com/HuguesTHOMAS/KPConv
33

44
#include "neighbors.h"
5+
#include <chrono>
56
#include <random>
67

78
template <typename scalar_t>
@@ -29,6 +30,7 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
2930
// CLoud variable
3031
PointCloud<scalar_t> pcd;
3132
pcd.set(supports);
33+
3234
// Cloud query
3335
PointCloud<scalar_t> pcd_query;
3436
pcd_query.set(queries);
@@ -50,17 +52,17 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
5052
// Search params
5153
nanoflann::SearchParams search_params;
5254
search_params.sorted = sorted;
53-
std::vector<std::vector<std::pair<size_t, scalar_t>>> list_matches(pcd_query.pts.size());
55+
auto num_query_points = pcd_query.get_point_count();
56+
std::vector<std::vector<std::pair<size_t, scalar_t>>> list_matches(num_query_points);
5457

55-
for (auto& p0 : pcd_query.pts)
58+
for (size_t i = 0; i < num_query_points; i++)
5659
{
5760
// Find neighbors
58-
scalar_t query_pt[3] = {p0.x, p0.y, p0.z};
5961
list_matches[i0].reserve(max_count);
6062
std::vector<std::pair<size_t, scalar_t>> ret_matches;
6163

62-
const size_t nMatches =
63-
index->radiusSearch(&query_pt[0], search_radius, ret_matches, search_params);
64+
const size_t nMatches = index->radiusSearch(pcd_query.get_point_ptr(i), search_radius,
65+
ret_matches, search_params);
6466
if (nMatches == 0)
6567
list_matches[i0] = {std::make_pair(0, -1)};
6668
else
@@ -164,10 +166,11 @@ int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppo
164166
PointCloud<scalar_t> current_cloud;
165167
PointCloud<scalar_t> query_pcd;
166168
query_pcd.set(queries);
167-
vector<vector<pair<size_t, scalar_t>>> all_inds_dists(query_pcd.pts.size());
169+
auto num_query_points = query_pcd.get_point_count();
170+
vector<vector<pair<size_t, scalar_t>>> all_inds_dists(num_query_points);
168171

169172
// Tree parameters
170-
nanoflann::KDTreeSingleIndexAdaptorParams tree_params(10 /* max leaf */);
173+
nanoflann::KDTreeSingleIndexAdaptorParams tree_params(15 /* max leaf */);
171174

172175
// KDTree type definition
173176
typedef nanoflann::KDTreeSingleIndexAdaptor<
@@ -178,34 +181,33 @@ int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppo
178181
current_cloud.set_batch(supports, s_batches[b], s_batches[b + 1]);
179182
std::unique_ptr<my_kd_tree_t> index(new my_kd_tree_t(3, current_cloud, tree_params));
180183
index->buildIndex();
184+
181185
// Search neigbors indices
182186
// ***********************
183187
// Search params
184188
nanoflann::SearchParams search_params;
185189
search_params.sorted = sorted;
186-
for (auto& p0 : query_pcd.pts)
190+
std::chrono::microseconds duration_search(0);
191+
for (size_t i = 0; i < num_query_points; i++)
187192
{
188193
// Check if we changed batch
189-
190194
if (i0 == q_batches[b + 1] && b < (int)s_batches.size() - 1 &&
191195
b < (int)q_batches.size() - 1)
192196
{
193197
// Change the points
194198
b++;
195-
current_cloud.pts.clear();
196199
if (s_batches[b] < s_batches[b + 1])
197200
current_cloud.set_batch(supports, s_batches[b], s_batches[b + 1]);
198201

199-
// Build KDTree of the current element of the batch
200202
index.reset(new my_kd_tree_t(3, current_cloud, tree_params));
201203
index->buildIndex();
202204
}
203205

204206
// Find neighboors
205207
std::vector<std::pair<size_t, scalar_t>> ret_matches;
206208
ret_matches.reserve(max_count);
207-
scalar_t query_pt[3] = {p0.x, p0.y, p0.z};
208-
size_t nMatches = index->radiusSearch(query_pt, r2, ret_matches, search_params);
209+
size_t nMatches =
210+
index->radiusSearch(query_pcd.get_point_ptr(i), r2, ret_matches, search_params);
209211

210212
// Shuffle if needed
211213
if (!sorted)
@@ -225,8 +227,8 @@ int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppo
225227
const int token = -1;
226228
if (mode == 0)
227229
{
228-
neighbors_indices.resize(query_pcd.pts.size() * max_count);
229-
dists.resize(query_pcd.pts.size() * max_count);
230+
neighbors_indices.resize(query_pcd.get_point_count() * max_count);
231+
dists.resize(query_pcd.get_point_count() * max_count);
230232
i0 = 0;
231233
b = 0;
232234

@@ -319,14 +321,15 @@ void nanoflann_knn_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppor
319321
// Search neigbors indices
320322
// ***********************
321323
size_t current_pos = 0;
322-
for (auto& p0 : pcd_query.pts)
324+
auto num_query_points = pcd_query.get_point_count();
325+
for (size_t i = 0; i < num_query_points; i++)
323326
{
324327
// Find neighbors
325-
scalar_t query_pt[3] = {p0.x, p0.y, p0.z};
326328
std::vector<size_t> ret_index(k);
327329
std::vector<scalar_t> out_dist_sqr(k);
328330

329-
const size_t nMatches = index->knnSearch(&query_pt[0], k, &ret_index[0], &out_dist_sqr[0]);
331+
const size_t nMatches =
332+
index->knnSearch(pcd_query.get_point_ptr(i), k, &ret_index[0], &out_dist_sqr[0]);
330333
for (size_t i = 0; i < nMatches; i++)
331334
{
332335
neighbors_indices[i + current_pos] = ret_index[i];

test/speed_radius.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import torch
2+
import os
3+
import sys
4+
import numpy.testing as npt
5+
import numpy as np
6+
from sklearn.neighbors import KDTree
7+
import unittest
8+
import time
9+
10+
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
11+
sys.path.insert(0, ROOT)
12+
13+
from torch_points_kernels import ball_query
14+
15+
16+
class TestRadiusSpeed(unittest.TestCase):
17+
def test_speed(self):
18+
start = time.time()
19+
a = torch.randn(50000, 3).to(torch.float)
20+
b = torch.randn(10000, 3).to(torch.float)
21+
batch_a = torch.tensor([0 for i in range(a.shape[0] // 2)] + [1 for i in range(a.shape[0] // 2, a.shape[0])])
22+
batch_b = torch.tensor([0 for i in range(b.shape[0] // 2)] + [1 for i in range(b.shape[0] // 2, b.shape[0])])
23+
R = 1
24+
samples = 50
25+
26+
idx, dist = ball_query(R, samples, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True)
27+
idx1, dist = ball_query(R, samples, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True)
28+
print(time.time() - start)
29+
torch.testing.assert_allclose(idx1, idx)
30+
31+
self.assertEqual(idx.shape[0], b.shape[0])
32+
self.assertEqual(dist.shape[0], b.shape[0])
33+
self.assertLessEqual(idx.max().item(), len(batch_a))
34+
35+
# # Comparison to see if we have the same result
36+
# tree = KDTree(a.detach().numpy())
37+
# idx3_sk = tree.query_radius(b.detach().numpy(), r=R)
38+
# i = np.random.randint(len(batch_b))
39+
# for p in idx[i].detach().numpy():
40+
# if p >= 0 and p < len(batch_a):
41+
# assert p in idx3_sk[i]
42+
43+
if __name__ == "__main__":
44+
unittest.main()

test/test_ballquerry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_simple_gpu(self):
7070
dist2 = dist2.detach().cpu().numpy()
7171

7272
idx_answer = np.asarray([[1, -1]])
73-
dist2_answer = np.asarray([[0.100, -1.0000]]).astype(np.float32)
73+
dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32)
7474

7575
npt.assert_array_almost_equal(idx, idx_answer)
7676
npt.assert_array_almost_equal(dist2, dist2_answer)
@@ -88,7 +88,7 @@ def test_simple_cpu(self):
8888
dist2 = dist2.detach().cpu().numpy()
8989

9090
idx_answer = np.asarray([[1, -1]])
91-
dist2_answer = np.asarray([[0.100, -1.0000]]).astype(np.float32)
91+
dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32)
9292

9393
npt.assert_array_almost_equal(idx, idx_answer)
9494
npt.assert_array_almost_equal(dist2, dist2_answer)

torch_points_kernels/torchpoints.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,6 @@ def ball_query_dense(radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xy
154154
ind, dist = tpcuda.ball_query_dense(new_xyz, xyz, radius, nsample)
155155
else:
156156
ind, dist = tpcpu.dense_ball_query(new_xyz, xyz, radius, nsample, mode=0, sorted=sort)
157-
positive = dist > 0
158-
dist[positive] = torch.sqrt(dist[positive])
159157
return ind, dist
160158

161159

@@ -167,8 +165,6 @@ def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=False
167165
ind, dist = tpcuda.ball_query_partial_dense(x, y, batch_x, batch_y, radius, nsample)
168166
else:
169167
ind, dist = tpcpu.batch_ball_query(x, y, batch_x, batch_y, radius, nsample, mode=0, sorted=sort)
170-
positive = dist > 0
171-
dist[positive] = torch.sqrt(dist[positive])
172168
return ind, dist
173169

174170

@@ -200,7 +196,7 @@ def ball_query(
200196
Returns:
201197
idx: (npoint, nsample) or (B, npoint, nsample) [dense] It contains the indexes of the element within x at radius distance to y
202198
dist: (N, nsample) or (B, npoint, nsample) Default value: -1.
203-
It contains the distance of the element within x at radius distance to y
199+
It contains the squared distance of the element within x at radius distance to y
204200
"""
205201
if mode is None:
206202
raise Exception('The mode should be defined within ["partial_dense | dense"]')

0 commit comments

Comments
 (0)