Skip to content

Commit 340677a

Browse files
Knn interpolate on CPU
1 parent abc4db9 commit 340677a

File tree

16 files changed

+200
-104
lines changed

16 files changed

+200
-104
lines changed

cpu/include/interpolate.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
#include <torch/extension.h>
3+
4+
// Weighted neighbour interpolation on CPU tensors.
// For every (b, c, p) the result is the sum over i of
//   features[b][c][idx[b][p][i]] * weight[b][p][i].
// All three inputs are accessed as 3-D tensors; `idx` is read as a long tensor
// and `weight` with the same scalar type as `features`.
// Returns a tensor of shape {features.size(0), features.size(1), idx.size(1)}
// with the dtype and device of `features`.
at::Tensor knn_interpolate(at::Tensor features, at::Tensor idx, at::Tensor weight);
5+
6+
// Scatter-add counterpart of knn_interpolate on CPU tensors.
// For every (b, c, p, i) it adds grad_out[b][c][p] * weight[b][p][i] into
// result[b][c][idx[b][p][i]].
// `m` is the size of the last dimension of the result, i.e. the result has
// shape {grad_out.size(0), grad_out.size(1), m}, with the dtype and device of
// `grad_out`. `idx` is read as a long tensor.
at::Tensor knn_interpolate_grad(at::Tensor grad_out, at::Tensor idx, at::Tensor weight,
7+
const int m);

cpu/include/knn.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#pragma once
22
#include <torch/extension.h>
3-
std::pair<at::Tensor, at::Tensor> dense_knn(at::Tensor query, at::Tensor support, int k);
3+
std::pair<at::Tensor, at::Tensor> dense_knn(at::Tensor support, at::Tensor query, int k);

cpu/include/utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33

44
#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be a CPU tensor")
55

6-
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be a contiguous tensor")
6+
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be a contiguous tensor")

cpu/src/bindings.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
#include "ball_query.h"
2+
// #include "fps.h"
3+
#include "interpolate.h"
24
#include "knn.h"
35

46
using namespace pybind11::literals;
57

68
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
79
{
8-
m.def("dense_knn", &dense_knn,"", "support"_a, "querry"_a, "k"_a);
10+
m.def("dense_knn", &dense_knn, "", "support"_a, "querry"_a, "k"_a);
11+
m.def("knn_interpolate", &knn_interpolate, "", "features"_a, "idx"_a, "weights"_a);
12+
m.def("knn_interpolate_grad", &knn_interpolate_grad, "", "grad_out"_a, "idx"_a, "weights"_a,
13+
"m"_a);
914

1015
m.def("ball_query", &ball_query,
1116
"compute the radius search of a point cloud using nanoflann"

cpu/src/interpolate.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#include "compat.h"
2+
#include "utils.h"
3+
#include <iostream>
4+
#include <torch/extension.h>
5+
6+
/**
 * Weighted neighbour interpolation of per-point features (CPU only).
 *
 * For every batch entry b, channel c and target point p:
 *
 *     output[b][c][p] = sum_i features[b][c][idx[b][p][i]] * weight[b][p][i]
 *
 * @param features 3-D contiguous CPU tensor, indexed as [batch][channel][source point].
 * @param idx      3-D contiguous CPU int64 tensor, indexed as [batch][target point][neighbour];
 *                 every value must be a valid index into the last dimension of `features`.
 * @param weight   3-D contiguous CPU tensor with the same dtype as `features`, indexed like `idx`.
 * @return Tensor of shape {features.size(0), features.size(1), idx.size(1)} with the dtype and
 *         device of `features`.
 * @throws c10::Error if a tensor is not contiguous / not on CPU, has the wrong rank or dtype,
 *         if `weight` or `features` do not cover the extent of `idx`, or if `idx` holds an
 *         out-of-range index.
 */
at::Tensor knn_interpolate(at::Tensor features, at::Tensor idx, at::Tensor weight)
{
    CHECK_CONTIGUOUS(features);
    CHECK_CONTIGUOUS(idx);
    CHECK_CONTIGUOUS(weight);
    CHECK_CPU(idx);
    CHECK_CPU(features);
    CHECK_CPU(weight);

    TORCH_CHECK(features.dim() == 3 && idx.dim() == 3 && weight.dim() == 3,
                "features, idx and weight must all be 3-dimensional tensors");
    TORCH_CHECK(idx.scalar_type() == at::kLong, "idx must be an int64 (long) tensor");
    TORCH_CHECK(weight.scalar_type() == features.scalar_type(),
                "weight must have the same dtype as features");

    const int64_t batch_size = idx.size(0);
    const int64_t num_points = idx.size(1);
    const int64_t num_neighbors = idx.size(2);
    const int64_t num_channels = features.size(1);
    const int64_t num_sources = features.size(2);

    // The loops below are driven by the extent of idx, so the other tensors must cover it;
    // otherwise the accessors would read or write past the end of their buffers.
    TORCH_CHECK(batch_size <= features.size(0),
                "idx has more batch entries than features");
    TORCH_CHECK(weight.size(0) >= batch_size && weight.size(1) >= num_points &&
                    weight.size(2) >= num_neighbors,
                "weight must provide one value for every entry of idx");

    // Validate the gather indices once up front instead of inside the hot loop.
    if (idx.numel() > 0)
    {
        TORCH_CHECK(idx.min().item<int64_t>() >= 0 && idx.max().item<int64_t>() < num_sources,
                    "idx contains an index outside the last dimension of features");
    }

    at::Tensor output = torch::zeros({features.size(0), num_channels, num_points},
                                     at::device(features.device()).dtype(features.scalar_type()));

    AT_DISPATCH_ALL_TYPES(features.scalar_type(), "knn_interpolate", [&] {
        auto output_a = output.accessor<scalar_t, 3>();
        auto features_a = features.accessor<scalar_t, 3>();
        auto weight_a = weight.accessor<scalar_t, 3>();
        // int64_t (not `long`) is the element type of a kLong tensor on every platform.
        auto idx_a = idx.accessor<int64_t, 3>();

        for (int64_t b = 0; b < batch_size; b++)
        {
            for (int64_t p = 0; p < num_points; p++)
            {
                for (int64_t c = 0; c < num_channels; c++)
                {
                    scalar_t acc = 0;
                    for (int64_t i = 0; i < num_neighbors; i++)
                        acc += features_a[b][c][idx_a[b][p][i]] * weight_a[b][p][i];
                    output_a[b][c][p] = acc;
                }
            }
        }
    });
    return output;
}
40+
41+
/**
 * Gradient of knn_interpolate with respect to `features` (CPU only).
 *
 * Scatter-adds the incoming gradient back onto the source points:
 *
 *     output[b][c][idx[b][p][i]] += grad_out[b][c][p] * weight[b][p][i]
 *
 * @param grad_out 3-D CPU tensor, indexed as [batch][channel][target point].
 * @param idx      3-D CPU int64 tensor, indexed as [batch][target point][neighbour];
 *                 every value must lie in [0, m).
 * @param weight   3-D CPU tensor with the same dtype as `grad_out`, indexed like `idx`.
 * @param m        Size of the last dimension of the returned tensor.
 * @return Tensor of shape {grad_out.size(0), grad_out.size(1), m} with the dtype and device of
 *         `grad_out`.
 * @throws c10::Error if a tensor is not on CPU, has the wrong rank or dtype, if `grad_out` or
 *         `weight` do not cover the extent of `idx`, or if `idx` holds an index outside [0, m).
 */
at::Tensor knn_interpolate_grad(at::Tensor grad_out, at::Tensor idx, at::Tensor weight, const int m)
{
    // All three tensors are dereferenced on the host, so all three must live on the CPU.
    CHECK_CPU(grad_out);
    CHECK_CPU(idx);
    CHECK_CPU(weight);

    TORCH_CHECK(grad_out.dim() == 3 && idx.dim() == 3 && weight.dim() == 3,
                "grad_out, idx and weight must all be 3-dimensional tensors");
    TORCH_CHECK(idx.scalar_type() == at::kLong, "idx must be an int64 (long) tensor");
    TORCH_CHECK(weight.scalar_type() == grad_out.scalar_type(),
                "weight must have the same dtype as grad_out");
    TORCH_CHECK(m >= 0, "m must be non-negative");

    const int64_t batch_size = idx.size(0);
    const int64_t num_points = idx.size(1);
    const int64_t num_neighbors = idx.size(2);
    const int64_t num_channels = grad_out.size(1);

    // The loops below are driven by the extent of idx, so the other tensors must cover it.
    TORCH_CHECK(grad_out.size(0) >= batch_size && grad_out.size(2) >= num_points,
                "grad_out must provide one gradient for every target point in idx");
    TORCH_CHECK(weight.size(0) >= batch_size && weight.size(1) >= num_points &&
                    weight.size(2) >= num_neighbors,
                "weight must provide one value for every entry of idx");

    // idx values are used as write positions into a buffer of extent m; reject anything
    // that would write outside of it. Checked once here rather than inside the hot loop.
    if (idx.numel() > 0)
    {
        TORCH_CHECK(idx.min().item<int64_t>() >= 0 && idx.max().item<int64_t>() < m,
                    "idx contains an index outside [0, m)");
    }

    at::Tensor output = torch::zeros({grad_out.size(0), num_channels, m},
                                     at::device(grad_out.device()).dtype(grad_out.scalar_type()));

    AT_DISPATCH_ALL_TYPES(grad_out.scalar_type(), "knn_interpolate_grad", [&] {
        auto output_a = output.accessor<scalar_t, 3>();
        auto grad_out_a = grad_out.accessor<scalar_t, 3>();
        auto weight_a = weight.accessor<scalar_t, 3>();
        // int64_t (not `long`) is the element type of a kLong tensor on every platform.
        auto idx_a = idx.accessor<int64_t, 3>();

        for (int64_t b = 0; b < batch_size; b++)
        {
            for (int64_t p = 0; p < num_points; p++)
            {
                for (int64_t c = 0; c < num_channels; c++)
                {
                    const scalar_t upstream = grad_out_a[b][c][p];
                    for (int64_t i = 0; i < num_neighbors; i++)
                    {
                        const int64_t target = idx_a[b][p][i];
                        output_a[b][c][target] += upstream * weight_a[b][p][i];
                    }
                }
            }
        }
    });
    return output;
}

cpu/src/knn.cpp

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,22 @@
1-
#include "ball_query.h"
21
#include "compat.h"
32
#include "neighbors.cpp"
43
#include "neighbors.h"
54
#include "utils.h"
65
#include <iostream>
76
#include <torch/extension.h>
87

9-
108
std::pair<at::Tensor, at::Tensor> _single_batch_knn(at::Tensor support, at::Tensor query, int k)
119
{
1210
CHECK_CONTIGUOUS(support);
1311
CHECK_CONTIGUOUS(query);
1412
if (support.size(0) < k)
15-
TORCH_CHECK(false, "Not enough points in support to find "+ std::to_string(k) + " neighboors")
16-
17-
at::Tensor out;
18-
at::Tensor out_dists;
19-
std::vector<long> neighbors_indices(query.size(0), -1);
20-
std::vector<float> neighbors_dists(query.size(0), -1);
13+
TORCH_CHECK(false,
14+
"Not enough points in support to find " + std::to_string(k) + " neighboors")
15+
std::vector<long> neighbors_indices(query.size(0) * k, -1);
16+
std::vector<float> neighbors_dists(query.size(0) * k, -1);
2117

2218
auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
23-
auto options_dist = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
24-
19+
auto options_dist = torch::TensorOptions().dtype(query.scalar_type()).device(torch::kCPU);
2520
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "knn", [&] {
2621
auto data_q = query.DATA_PTR<scalar_t>();
2722
auto data_s = support.DATA_PTR<scalar_t>();
@@ -31,12 +26,13 @@ std::pair<at::Tensor, at::Tensor> _single_batch_knn(at::Tensor support, at::Tens
3126
std::vector<scalar_t>(data_s, data_s + support.size(0) * support.size(1));
3227

3328
nanoflann_knn_neighbors<scalar_t>(queries_stl, supports_stl, neighbors_indices,
34-
neighbors_dists, k);
29+
neighbors_dists, k);
3530
});
3631
auto neighbors_dists_ptr = neighbors_dists.data();
3732
long* neighbors_indices_ptr = neighbors_indices.data();
38-
out = torch::from_blob(neighbors_indices_ptr, {query.size(0), k}, options = options);
39-
out_dists = torch::from_blob(neighbors_dists_ptr, {query.size(0), k}, options = options_dist);
33+
auto out = torch::from_blob(neighbors_indices_ptr, {query.size(0), k}, options = options);
34+
auto out_dists =
35+
torch::from_blob(neighbors_dists_ptr, {query.size(0), k}, options = options_dist);
4036

4137
return std::make_pair(out.clone(), out_dists.clone());
4238
}
@@ -45,6 +41,8 @@ std::pair<at::Tensor, at::Tensor> dense_knn(at::Tensor support, at::Tensor query
4541
{
4642
CHECK_CONTIGUOUS(support);
4743
CHECK_CONTIGUOUS(query);
44+
CHECK_CPU(query);
45+
CHECK_CPU(support);
4846

4947
int b = query.size(0);
5048
vector<at::Tensor> batch_idx;

cpu/src/neighbors.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
128128
return max_count;
129129
}
130130

131-
132131
template <typename scalar_t>
133132
int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
134133
vector<long>& q_batches, vector<long>& s_batches,
@@ -283,11 +282,10 @@ int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppo
283282

284283
template <typename scalar_t>
285284
void nanoflann_knn_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
286-
vector<long>& neighbors_indices, vector<float>& dists, int k)
285+
vector<long>& neighbors_indices, vector<float>& dists, int k)
287286
{
288287
// Nanoflann related variables
289288
// ***************************
290-
291289
// Cloud variable
292290
PointCloud<scalar_t> pcd;
293291
pcd.set(supports);
@@ -315,12 +313,10 @@ void nanoflann_knn_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppor
315313
// Find neighbors
316314
scalar_t query_pt[3] = {p0.x, p0.y, p0.z};
317315
std::vector<size_t> ret_index(k);
318-
std::vector<scalar_t> out_dist_sqr(k);
316+
std::vector<scalar_t> out_dist_sqr(k);
319317

320-
const size_t nMatches =
321-
index->knnSearch(&query_pt[0], k, &ret_index[0], &out_dist_sqr[0]);
322-
323-
for (size_t i=0; i < nMatches; i++)
318+
const size_t nMatches = index->knnSearch(&query_pt[0], k, &ret_index[0], &out_dist_sqr[0]);
319+
for (size_t i = 0; i < nMatches; i++)
324320
{
325321
neighbors_indices[i + current_pos] = ret_index[i];
326322
dists[i + current_pos] = out_dist_sqr[i];

cuda/include/ball_query.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#pragma once
22
#include <torch/extension.h>
33

4-
at::Tensor ball_query_dense(at::Tensor new_xyz, at::Tensor xyz, const float radius,
5-
const int nsample);
4+
std::pair<at::Tensor, at::Tensor> ball_query_dense(at::Tensor new_xyz, at::Tensor xyz,
5+
const float radius, const int nsample);
66

77
std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x, at::Tensor y,
88
at::Tensor batch_x, at::Tensor batch_y,

cuda/src/ball_query.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
#include "utils.h"
44

55
void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, int nsample,
6-
const float* new_xyz, const float* xyz, int* idx, float* dist_out);
6+
const float* new_xyz, const float* xyz, int* idx,
7+
float* dist_out);
78

89
void query_ball_point_kernel_partial_wrapper(long batch_size, int size_x, int size_y, float radius,
910
int nsample, const float* x, const float* y,
1011
const long* batch_x, const long* batch_y,
1112
long* idx_out, float* dist_out);
1213

13-
at::Tensor ball_query_dense(at::Tensor new_xyz, at::Tensor xyz, const float radius,
14-
const int nsample)
14+
std::pair<at::Tensor, at::Tensor> ball_query_dense(at::Tensor new_xyz, at::Tensor xyz,
15+
const float radius, const int nsample)
1516
{
1617
CHECK_CONTIGUOUS(new_xyz);
1718
CHECK_CONTIGUOUS(xyz);
@@ -25,20 +26,19 @@ at::Tensor ball_query_dense(at::Tensor new_xyz, at::Tensor xyz, const float radi
2526

2627
at::Tensor idx = torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
2728
at::device(new_xyz.device()).dtype(at::ScalarType::Int));
28-
at::Tensor dist =
29-
torch::full({new_xyz.size(0), new_xyz.size(1), nsample}, -1, at::device(new_xyz.device()).dtype(at::ScalarType::Float));
29+
at::Tensor dist = torch::full({new_xyz.size(0), new_xyz.size(1), nsample}, -1,
30+
at::device(new_xyz.device()).dtype(at::ScalarType::Float));
3031

3132
if (new_xyz.type().is_cuda())
3233
{
33-
query_ball_point_kernel_dense_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), radius,
34-
nsample, new_xyz.DATA_PTR<float>(),
35-
xyz.DATA_PTR<float>(), idx.DATA_PTR<int>(), dist.DATA_PTR<int>());
34+
query_ball_point_kernel_dense_wrapper(
35+
xyz.size(0), xyz.size(1), new_xyz.size(1), radius, nsample, new_xyz.DATA_PTR<float>(),
36+
xyz.DATA_PTR<float>(), idx.DATA_PTR<int>(), dist.DATA_PTR<float>());
3637
}
3738
else
3839
{
3940
TORCH_CHECK(false, "CPU not supported");
4041
}
41-
4242
return std::make_pair(idx, dist);
4343
}
4444

cuda/src/ball_query_gpu.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ __global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius,
1616
xyz += batch_index * n * 3;
1717
new_xyz += batch_index * m * 3;
1818
idx_out += m * nsample * batch_index;
19+
dist_out += m * nsample * batch_index;
1920

2021
int index = threadIdx.x;
2122
int stride = blockDim.x;
@@ -43,7 +44,7 @@ __global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius,
4344
}
4445
}
4546
idx_out[j * nsample + cnt] = k;
46-
dist_out[j * nsample + cnt] = d2
47+
dist_out[j * nsample + cnt] = d2;
4748
++cnt;
4849
}
4950
}

0 commit comments

Comments
 (0)