1+ #include " ball_query.h"
2+ #include " compat.h"
3+ #include " neighbors.cpp"
4+ #include " neighbors.h"
5+ #include " utils.h"
6+ #include < iostream>
7+ #include < torch/extension.h>
8+
9+
10+ std::pair<at::Tensor, at::Tensor> _single_batch_knn (at::Tensor support, at::Tensor query, int k)
11+ {
12+ CHECK_CONTIGUOUS (support);
13+ CHECK_CONTIGUOUS (query);
14+ if (support.size (0 ) < k)
15+ TORCH_CHECK (false , " Not enough points in support to find " + std::to_string (k) + " neighboors" )
16+
17+ at::Tensor out;
18+ at::Tensor out_dists;
19+ std::vector<long > neighbors_indices (query.size (0 ), -1 );
20+ std::vector<float > neighbors_dists (query.size (0 ), -1 );
21+
22+ auto options = torch::TensorOptions ().dtype (torch::kLong ).device (torch::kCPU );
23+ auto options_dist = torch::TensorOptions ().dtype (torch::kFloat32 ).device (torch::kCPU );
24+
25+ AT_DISPATCH_ALL_TYPES (query.scalar_type (), " knn" , [&] {
26+ auto data_q = query.DATA_PTR <scalar_t >();
27+ auto data_s = support.DATA_PTR <scalar_t >();
28+ std::vector<scalar_t > queries_stl =
29+ std::vector<scalar_t >(data_q, data_q + query.size (0 ) * query.size (1 ));
30+ std::vector<scalar_t > supports_stl =
31+ std::vector<scalar_t >(data_s, data_s + support.size (0 ) * support.size (1 ));
32+
33+ nanoflann_knn_neighbors<scalar_t >(queries_stl, supports_stl, neighbors_indices,
34+ neighbors_dists, k);
35+ });
36+ auto neighbors_dists_ptr = neighbors_dists.data ();
37+ long * neighbors_indices_ptr = neighbors_indices.data ();
38+ out = torch::from_blob (neighbors_indices_ptr, {query.size (0 ), k}, options = options);
39+ out_dists = torch::from_blob (neighbors_dists_ptr, {query.size (0 ), k}, options = options_dist);
40+
41+ return std::make_pair (out.clone (), out_dists.clone ());
42+ }
43+
44+ std::pair<at::Tensor, at::Tensor> dense_knn (at::Tensor support, at::Tensor query, int k)
45+ {
46+ CHECK_CONTIGUOUS (support);
47+ CHECK_CONTIGUOUS (query);
48+
49+ int b = query.size (0 );
50+ vector<at::Tensor> batch_idx;
51+ vector<at::Tensor> batch_dist;
52+ for (int i = 0 ; i < b; i++)
53+ {
54+ auto out_pair = _single_batch_knn (support[i], query[i], k);
55+ batch_idx.push_back (out_pair.first );
56+ batch_dist.push_back (out_pair.second );
57+ }
58+ auto out_idx = torch::stack (batch_idx);
59+ auto out_dist = torch::stack (batch_dist);
60+ return std::make_pair (out_idx, out_dist);
61+ }
0 commit comments