Skip to content

Commit a0dd45b

Browse files
Adding option for sorting ball query results (off by default)
1 parent 5ee42e8 commit a0dd45b

File tree

7 files changed

+78
-68
lines changed

7 files changed

+78
-68
lines changed

cpu/include/ball_query.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
#pragma once
22
#include <torch/extension.h>
33
std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor query, at::Tensor support, float radius,
4-
int max_num, int mode);
4+
int max_num, int mode, bool sorted);
55

66
std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query, at::Tensor support,
77
at::Tensor query_batch, at::Tensor support_batch,
8-
float radius, int max_num, int mode);
8+
float radius, int max_num, int mode,
9+
bool sorted);
910

1011
std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor query, at::Tensor support,
11-
float radius, int max_num, int mode);
12+
float radius, int max_num, int mode,
13+
bool sorted);

cpu/include/neighbors.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ using namespace std;
1010
template <typename scalar_t>
1111
int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
1212
vector<long>& neighbors_indices, vector<float>& dists, float radius,
13-
int max_num, int mode);
13+
int max_num, int mode, bool sorted);
1414

1515
template <typename scalar_t>
1616
int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
1717
vector<long>& q_batches, vector<long>& s_batches,
1818
vector<long>& neighbors_indices, vector<float>& dists, float radius,
19-
int max_num, int mode);
19+
int max_num, int mode, bool sorted);
2020

2121
template <typename scalar_t>
2222
void nanoflann_knn_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,

cpu/src/ball_query.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include <torch/extension.h>
99

1010
std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor support, at::Tensor query, float radius,
11-
int max_num, int mode)
11+
int max_num, int mode, bool sorted)
1212
{
1313
CHECK_CONTIGUOUS(support);
1414
CHECK_CONTIGUOUS(query);
@@ -31,7 +31,7 @@ std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor support, at::Tensor quer
3131
std::vector<scalar_t>(data_s, data_s + support.size(0) * support.size(1));
3232

3333
max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl, neighbors_indices,
34-
neighbors_dists, radius, max_num, mode);
34+
neighbors_dists, radius, max_num, mode, sorted);
3535
});
3636
auto neighbors_dists_ptr = neighbors_dists.data();
3737
long* neighbors_indices_ptr = neighbors_indices.data();
@@ -62,7 +62,7 @@ at::Tensor degree(at::Tensor row, int64_t num_nodes)
6262

6363
std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tensor query,
6464
at::Tensor support_batch, at::Tensor query_batch,
65-
float radius, int max_num, int mode)
65+
float radius, int max_num, int mode, bool sorted)
6666
{
6767
CHECK_CONTIGUOUS(support);
6868
CHECK_CONTIGUOUS(query);
@@ -97,9 +97,9 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
9797
std::vector<scalar_t> supports_stl(support.DATA_PTR<scalar_t>(),
9898
support.DATA_PTR<scalar_t>() + support.numel());
9999

100-
max_count = batch_nanoflann_neighbors<scalar_t>(queries_stl, supports_stl, query_batch_stl,
101-
support_batch_stl, neighbors_indices,
102-
neighbors_dists, radius, max_num, mode);
100+
max_count = batch_nanoflann_neighbors<scalar_t>(
101+
queries_stl, supports_stl, query_batch_stl, support_batch_stl, neighbors_indices,
102+
neighbors_dists, radius, max_num, mode, sorted);
103103
});
104104
auto neighbors_dists_ptr = neighbors_dists.data();
105105
long* neighbors_indices_ptr = neighbors_indices.data();
@@ -122,7 +122,7 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
122122
}
123123

124124
std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor support, at::Tensor query,
125-
float radius, int max_num, int mode)
125+
float radius, int max_num, int mode, bool sorted)
126126
{
127127
CHECK_CONTIGUOUS(support);
128128
CHECK_CONTIGUOUS(query);
@@ -132,7 +132,7 @@ std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor support, at::Tenso
132132
vector<at::Tensor> batch_dist;
133133
for (int i = 0; i < b; i++)
134134
{
135-
auto out_pair = ball_query(query[i], support[i], radius, max_num, mode);
135+
auto out_pair = ball_query(query[i], support[i], radius, max_num, mode, sorted);
136136
batch_idx.push_back(out_pair.first);
137137
batch_dist.push_back(out_pair.second);
138138
}

cpu/src/bindings.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
2828
"maximum number of neighbors found if mode = 0, if mode=1 return a "
2929
"tensor of size Num_edge x 2 and return a tensor containing the "
3030
"squared distance of the neighbors",
31-
"support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0);
31+
"support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0, "sorted"_a = false);
3232

3333
m.def("batch_ball_query", &batch_ball_query,
3434
"compute the radius search of a point cloud for each batch using "
@@ -53,7 +53,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
5353
"tensor of size Num_edge x 2 and return a tensor containing the "
5454
"squared distance of the neighbors",
5555
"support"_a, "querry"_a, "query_batch"_a, "support_batch"_a, "radius"_a, "max_num"_a = -1,
56-
"mode"_a = 0);
56+
"mode"_a = 0, "sorted"_a = false);
5757
m.def("dense_ball_query", &dense_ball_query,
5858
"compute the radius search of a batch of point cloud using nanoflann"
5959
"- support : a pytorch tensor of size B x N1 x 3, points where the "
@@ -69,5 +69,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
6969
"maximum number of neighbors found if mode = 0, if mode=1 return a "
7070
"tensor of size Num_edge x 2 and return a tensor containing the "
7171
"squared distance of the neighbors",
72-
"support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0);
72+
"support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0, "sorted"_a = false);
7373
}

cpu/src/neighbors.cpp

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
template <typename scalar_t>
88
int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
99
vector<long>& neighbors_indices, vector<float>& dists, float radius,
10-
int max_num, int mode)
10+
int max_num, int mode, bool sorted)
1111
{
1212
// Initiate variables
1313
// ******************
14+
std::random_device rd;
15+
std::mt19937 g(rd());
1416

1517
// square radius
16-
1718
const float search_radius = static_cast<float>(radius * radius);
1819

1920
// indices
@@ -48,7 +49,7 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
4849

4950
// Search params
5051
nanoflann::SearchParams search_params;
51-
search_params.sorted = false;
52+
search_params.sorted = sorted;
5253
std::vector<std::vector<std::pair<size_t, scalar_t>>> list_matches(pcd_query.pts.size());
5354

5455
for (auto& p0 : pcd_query.pts)
@@ -64,9 +65,8 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
6465
list_matches[i0] = {std::make_pair(0, -1)};
6566
else
6667
{
67-
std::random_device rd;
68-
std::mt19937 g(rd());
69-
std::shuffle(ret_matches.begin(), ret_matches.end(), g);
68+
if (!sorted)
69+
std::shuffle(ret_matches.begin(), ret_matches.end(), g);
7070
list_matches[i0] = ret_matches;
7171
}
7272
max_count = max(max_count, nMatches);
@@ -138,10 +138,13 @@ template <typename scalar_t>
138138
int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
139139
vector<long>& q_batches, vector<long>& s_batches,
140140
vector<long>& neighbors_indices, vector<float>& dists, float radius,
141-
int max_num, int mode)
141+
int max_num, int mode, bool sorted)
142142
{
143143
// Initiate variables
144144
// ******************
145+
std::random_device rd;
146+
std::mt19937 g(rd());
147+
145148
// indices
146149
int i0 = 0;
147150

@@ -179,7 +182,7 @@ int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppo
179182
// ***********************
180183
// Search params
181184
nanoflann::SearchParams search_params;
182-
search_params.sorted = false;
185+
search_params.sorted = sorted;
183186
for (auto& p0 : query_pcd.pts)
184187
{
185188
// Check if we changed batch
@@ -198,16 +201,18 @@ int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppo
198201
index->buildIndex();
199202
}
200203

201-
// Initial guess of neighbors size
202-
203-
all_inds_dists[i0].reserve(max_count);
204-
// Find neighbors
205-
// std::cerr << p0.x << p0.y << p0.z<<std::endl;
204+
// Find neighbors
205+
std::vector<std::pair<size_t, scalar_t>> ret_matches;
206+
ret_matches.reserve(max_count);
206207
scalar_t query_pt[3] = {p0.x, p0.y, p0.z};
208+
size_t nMatches = index->radiusSearch(query_pt, r2, ret_matches, search_params);
207209

208-
size_t nMatches = index->radiusSearch(query_pt, r2, all_inds_dists[i0], search_params);
209-
// Update max count
210+
// Shuffle if needed
211+
if (!sorted)
212+
std::shuffle(ret_matches.begin(), ret_matches.end(), g);
213+
all_inds_dists[i0] = ret_matches;
210214

215+
// Update max count
211216
if (nMatches > (size_t)max_count)
212217
max_count = nMatches;
213218
// Increment query idx

test/test_ballquerry.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ def test_simple_gpu(self):
2020
def test_simple_cpu(self):
2121
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]], [[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float)
2222
b = torch.tensor([[[0, 0, 0]], [[3, 0, 0]]]).to(torch.float)
23-
idx, dist = ball_query(1.01, 2, a, b)
23+
idx, dist = ball_query(1.01, 2, a, b, sort=True)
2424
torch.testing.assert_allclose(idx, torch.tensor([[[0, 1]], [[2, 2]]]))
2525
torch.testing.assert_allclose(dist, torch.tensor([[[0, 1]], [[1, -1]]]).float())
2626

2727
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [1, 1, 0]]]).to(torch.float)
28-
idx, dist = ball_query(1.01, 3, a, a)
28+
idx, dist = ball_query(1.01, 3, a, a, sort=True)
2929
torch.testing.assert_allclose(idx, torch.tensor([[[0, 1, 0], [1, 0, 2], [2, 1, 2]]]))
3030

3131
@run_if_cuda
@@ -70,7 +70,7 @@ def test_simple_gpu(self):
7070
dist2 = dist2.detach().cpu().numpy()
7171

7272
idx_answer = np.asarray([[1, -1]])
73-
dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32)
73+
dist2_answer = np.asarray([[0.100, -1.0000]]).astype(np.float32)
7474

7575
npt.assert_array_almost_equal(idx, idx_answer)
7676
npt.assert_array_almost_equal(dist2, dist2_answer)
@@ -88,7 +88,7 @@ def test_simple_cpu(self):
8888
dist2 = dist2.detach().cpu().numpy()
8989

9090
idx_answer = np.asarray([[1, -1]])
91-
dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32)
91+
dist2_answer = np.asarray([[0.100, -1.0000]]).astype(np.float32)
9292

9393
npt.assert_array_almost_equal(idx, idx_answer)
9494
npt.assert_array_almost_equal(dist2, dist2_answer)
@@ -100,9 +100,13 @@ def test_random_cpu(self):
100100
batch_b = torch.tensor([0 for i in range(b.shape[0] // 2)] + [1 for i in range(b.shape[0] // 2, b.shape[0])])
101101
R = 1
102102

103-
idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b)
104-
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b)
103+
idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True)
104+
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True)
105105
torch.testing.assert_allclose(idx1, idx)
106+
with self.assertRaises(AssertionError):
107+
idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False)
108+
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False)
109+
torch.testing.assert_allclose(idx1, idx)
106110

107111
self.assertEqual(idx.shape[0], b.shape[0])
108112
self.assertEqual(dist.shape[0], b.shape[0])

torch_points/torchpoints.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -146,33 +146,30 @@ def grouping_operation(features, idx):
146146
return grouped_features.reshape(idx.shape[0], features.shape[1], idx.shape[1], idx.shape[2])
147147

148148

149-
class BallQueryDense(Function):
150-
@staticmethod
151-
def forward(ctx, radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=None):
152-
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
153-
if new_xyz.is_cuda:
154-
return tpcuda.ball_query_dense(new_xyz, xyz, radius, nsample)
155-
else:
156-
return tpcpu.dense_ball_query(new_xyz, xyz, radius, nsample, mode=0)
157-
158-
@staticmethod
159-
def backward(ctx, a=None):
160-
return None, None, None, None
161-
162-
163-
class BallQueryPartialDense(Function):
164-
@staticmethod
165-
def forward(ctx, radius, nsample, x, y, batch_x, batch_y):
166-
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
167-
if x.is_cuda:
168-
return tpcuda.ball_query_partial_dense(x, y, batch_x, batch_y, radius, nsample)
169-
else:
170-
ind, dist = tpcpu.batch_ball_query(x, y, batch_x, batch_y, radius, nsample, mode=0)
171-
return ind, dist
172-
173-
@staticmethod
174-
def backward(ctx, a=None):
175-
return None, None, None, None
149+
def ball_query_dense(radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=None, sort=False):
150+
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
151+
if new_xyz.is_cuda:
152+
if sort:
153+
raise NotImplementedError("CUDA version does not sort the neighbors")
154+
ind, dist = tpcuda.ball_query_dense(new_xyz, xyz, radius, nsample)
155+
else:
156+
ind, dist = tpcpu.dense_ball_query(new_xyz, xyz, radius, nsample, mode=0, sorted=sort)
157+
positive = dist > 0
158+
dist[positive] = torch.sqrt(dist[positive])
159+
return ind, dist
160+
161+
162+
def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=False):
163+
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
164+
if x.is_cuda:
165+
if sort:
166+
raise NotImplementedError("CUDA version does not sort the neighbors")
167+
ind, dist = tpcuda.ball_query_partial_dense(x, y, batch_x, batch_y, radius, nsample)
168+
else:
169+
ind, dist = tpcpu.batch_ball_query(x, y, batch_x, batch_y, radius, nsample, mode=0, sorted=sort)
170+
positive = dist > 0
171+
dist[positive] = torch.sqrt(dist[positive])
172+
return ind, dist
176173

177174

178175
def ball_query(
@@ -183,6 +180,7 @@ def ball_query(
183180
mode: Optional[str] = "dense",
184181
batch_x: Optional[torch.tensor] = None,
185182
batch_y: Optional[torch.tensor] = None,
183+
sort: Optional[bool] = False,
186184
) -> torch.Tensor:
187185
"""
188186
Arguments:
@@ -197,11 +195,12 @@ def ball_query(
197195
Keyword Arguments:
198196
batch_x -- (M, ) [partial_dense] or (B, M, 3) [dense] Contains indexes to indicate within batch it belongs to.
199197
batch_y -- (N, ) Contains indexes to indicate within batch it belongs to
198+
sort -- bool, whether the neighbors are sorted or not (closest first)
200199
201200
Returns:
202201
idx: (npoint, nsample) or (B, npoint, nsample) [dense] It contains the indexes of the element within x at radius distance to y
203-
dist2: (N, nsample) or (B, npoint, nsample) Default value: -1.
204-
It contains the square distances of the element within x at radius distance to y
202+
dist: (N, nsample) or (B, npoint, nsample) Default value: -1.
203+
It contains the distance of the element within x at radius distance to y
205204
"""
206205
if mode is None:
207206
raise Exception('The mode should be defined within ["partial_dense | dense"]')
@@ -212,12 +211,12 @@ def ball_query(
212211
assert x.size(0) == batch_x.size(0)
213212
assert y.size(0) == batch_y.size(0)
214213
assert x.dim() == 2
215-
return BallQueryPartialDense.apply(radius, nsample, x, y, batch_x, batch_y)
214+
return ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=sort)
216215

217216
elif mode.lower() == "dense":
218217
if (batch_x is not None) or (batch_y is not None):
219218
raise Exception("batch_x and batch_y should not be provided")
220219
assert x.dim() == 3
221-
return BallQueryDense.apply(radius, nsample, x, y)
220+
return ball_query_dense(radius, nsample, x, y, sort=sort)
222221
else:
223222
raise Exception("unrecognized mode {}".format(mode))

0 commit comments

Comments
 (0)