Skip to content

Commit 4cf6eff

Browse files
KNN for dense data
1 parent c12403e commit 4cf6eff

File tree

9 files changed

+154
-0
lines changed

9 files changed

+154
-0
lines changed

cpu/include/knn.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#pragma once
2+
#include <torch/extension.h>
3+
std::pair<at::Tensor, at::Tensor> dense_knn(at::Tensor query, at::Tensor support, int k);

cpu/include/neighbors.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,7 @@ int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppo
1717
vector<long>& q_batches, vector<long>& s_batches,
1818
vector<long>& neighbors_indices, vector<float>& dists, float radius,
1919
int max_num, int mode);
20+
21+
template <typename scalar_t>
22+
void nanoflann_knn_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
23+
vector<long>& neighbors_indices, vector<float>& dists, int k);
File renamed without changes.

cpu/src/bindings.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
#include "ball_query.h"
2+
#include "knn.h"
23

34
using namespace pybind11::literals;
45

56
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
67
{
8+
m.def("dense_knn", &dense_knn,"", "support"_a, "querry"_a, "k"_a);
9+
710
m.def("ball_query", &ball_query,
811
"compute the radius search of a point cloud using nanoflann"
912
"- support : a pytorch tensor of size N1 x 3, points where the "

cpu/src/knn.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#include "ball_query.h"
2+
#include "compat.h"
3+
#include "neighbors.cpp"
4+
#include "neighbors.h"
5+
#include "utils.h"
6+
#include <iostream>
7+
#include <torch/extension.h>
8+
9+
10+
std::pair<at::Tensor, at::Tensor> _single_batch_knn(at::Tensor support, at::Tensor query, int k)
11+
{
12+
CHECK_CONTIGUOUS(support);
13+
CHECK_CONTIGUOUS(query);
14+
if (support.size(0) < k)
15+
TORCH_CHECK(false, "Not enough points in support to find "+ std::to_string(k) + " neighboors")
16+
17+
at::Tensor out;
18+
at::Tensor out_dists;
19+
std::vector<long> neighbors_indices(query.size(0), -1);
20+
std::vector<float> neighbors_dists(query.size(0), -1);
21+
22+
auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
23+
auto options_dist = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
24+
25+
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "knn", [&] {
26+
auto data_q = query.DATA_PTR<scalar_t>();
27+
auto data_s = support.DATA_PTR<scalar_t>();
28+
std::vector<scalar_t> queries_stl =
29+
std::vector<scalar_t>(data_q, data_q + query.size(0) * query.size(1));
30+
std::vector<scalar_t> supports_stl =
31+
std::vector<scalar_t>(data_s, data_s + support.size(0) * support.size(1));
32+
33+
nanoflann_knn_neighbors<scalar_t>(queries_stl, supports_stl, neighbors_indices,
34+
neighbors_dists, k);
35+
});
36+
auto neighbors_dists_ptr = neighbors_dists.data();
37+
long* neighbors_indices_ptr = neighbors_indices.data();
38+
out = torch::from_blob(neighbors_indices_ptr, {query.size(0), k}, options = options);
39+
out_dists = torch::from_blob(neighbors_dists_ptr, {query.size(0), k}, options = options_dist);
40+
41+
return std::make_pair(out.clone(), out_dists.clone());
42+
}
43+
44+
std::pair<at::Tensor, at::Tensor> dense_knn(at::Tensor support, at::Tensor query, int k)
45+
{
46+
CHECK_CONTIGUOUS(support);
47+
CHECK_CONTIGUOUS(query);
48+
49+
int b = query.size(0);
50+
vector<at::Tensor> batch_idx;
51+
vector<at::Tensor> batch_dist;
52+
for (int i = 0; i < b; i++)
53+
{
54+
auto out_pair = _single_batch_knn(support[i], query[i], k);
55+
batch_idx.push_back(out_pair.first);
56+
batch_dist.push_back(out_pair.second);
57+
}
58+
auto out_idx = torch::stack(batch_idx);
59+
auto out_dist = torch::stack(batch_dist);
60+
return std::make_pair(out_idx, out_dist);
61+
}

cpu/src/neighbors.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
128128
return max_count;
129129
}
130130

131+
131132
template <typename scalar_t>
132133
int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
133134
vector<long>& q_batches, vector<long>& s_batches,
@@ -279,3 +280,51 @@ int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppo
279280
}
280281
return max_count;
281282
}
283+
284+
template <typename scalar_t>
285+
void nanoflann_knn_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
286+
vector<long>& neighbors_indices, vector<float>& dists, int k)
287+
{
288+
// Nanoflann related variables
289+
// ***************************
290+
291+
// CLoud variable
292+
PointCloud<scalar_t> pcd;
293+
pcd.set(supports);
294+
// Cloud query
295+
PointCloud<scalar_t> pcd_query;
296+
pcd_query.set(queries);
297+
298+
// Tree parameters
299+
nanoflann::KDTreeSingleIndexAdaptorParams tree_params(15 /* max leaf */);
300+
301+
// KDTree type definition
302+
typedef nanoflann::KDTreeSingleIndexAdaptor<
303+
nanoflann::L2_Simple_Adaptor<scalar_t, PointCloud<scalar_t>>, PointCloud<scalar_t>, 3>
304+
my_kd_tree_t;
305+
306+
// Pointer to trees
307+
std::unique_ptr<my_kd_tree_t> index(new my_kd_tree_t(3, pcd, tree_params));
308+
index->buildIndex();
309+
310+
// Search neigbors indices
311+
// ***********************
312+
size_t current_pos = 0;
313+
for (auto& p0 : pcd_query.pts)
314+
{
315+
// Find neighbors
316+
scalar_t query_pt[3] = {p0.x, p0.y, p0.z};
317+
std::vector<size_t> ret_index(k);
318+
std::vector<scalar_t> out_dist_sqr(k);
319+
320+
const size_t nMatches =
321+
index->knnSearch(&query_pt[0], k, &ret_index[0], &out_dist_sqr[0]);
322+
323+
for (size_t i=0; i < nMatches; i++)
324+
{
325+
neighbors_indices[i + current_pos] = ret_index[i];
326+
dists[i + current_pos] = out_dist_sqr[i];
327+
}
328+
current_pos += nMatches;
329+
}
330+
}

test/test_knn.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import unittest
2+
import torch
3+
from torch_points import knn
4+
5+
class TestKnn(unittest.TestCase):
6+
def test_cpu(self):
7+
support = torch.tensor([[[0,0,0],[1,0,0],[2,0,0]]])
8+
query = torch.tensor([[[0,0,0]]])
9+
10+
idx, dist = knn(support, query, 3)
11+
torch.testing.assert_allclose(idx, torch.tensor([[[0,1, 2]]]))
12+
torch.testing.assert_allclose(dist, torch.tensor([[[0.,1., 4.]]]))
13+
14+
idx, dist = knn(support, query, 2)
15+
torch.testing.assert_allclose(idx, torch.tensor([[[0,1]]]))
16+
17+
with self.assertRaises(RuntimeError):
18+
knn(support,query, 5)

torch_points/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .torchpoints import *
2+
from .knn import *
23

34
__all__ = [
45
"ball_query",
@@ -7,4 +8,5 @@
78
"grouping_operation",
89
"three_interpolate",
910
"three_nn",
11+
"knn"
1012
]

torch_points/knn.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch_points.points_cpu as tpcpu
2+
3+
def knn(pos_support, pos, k):
4+
""" Dense knn serach
5+
Arguments:
6+
pos_support - [B,N,3] support points
7+
pos - [B,M,3] centre of queries
8+
k - number of neighboors, needs to be > N
9+
10+
Returns:
11+
idx - [B,M,k]
12+
dist2 - [B,M,k] squared distances
13+
"""
14+
return tpcpu.dense_knn(pos_support, pos, k)

0 commit comments

Comments
 (0)