Skip to content

Commit 6739a3f

Browse files
fps on cpu
1 parent 31c06d8 commit 6739a3f

File tree

5 files changed

+68
-2
lines changed

5 files changed

+68
-2
lines changed

cpu/include/fps.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#pragma once
2+
#include <torch/extension.h>
3+
at::Tensor fps(at::Tensor points, const int nsamples, bool random = true);

cpu/src/bindings.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include "ball_query.h"
2-
// #include "fps.h"
2+
#include "fps.h"
33
#include "interpolate.h"
44
#include "knn.h"
55

@@ -11,6 +11,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
1111
m.def("knn_interpolate", &knn_interpolate, "", "features"_a, "idx"_a, "weights"_a);
1212
m.def("knn_interpolate_grad", &knn_interpolate_grad, "", "grad_out"_a, "idx"_a, "weights"_a,
1313
"m"_a);
14+
m.def("fps", &fps, "", "points"_a, "num_samples"_a, "random"_a);
1415

1516
m.def("ball_query", &ball_query,
1617
"compute the radius search of a point cloud using nanoflann"

cpu/src/fps.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#include <torch/extension.h>
2+
3+
#include "compat.h"
4+
#include "utils.h"
5+
6+
at::Tensor get_dist(at::Tensor x, ptrdiff_t index)
7+
{
8+
return (x - x[index]).norm(2, 1);
9+
}
10+
11+
at::Tensor fps(at::Tensor points, const int nsamples, bool random)
12+
{
13+
CHECK_CONTIGUOUS(points);
14+
15+
auto out_options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
16+
auto batch_size = points.size(0);
17+
auto out = torch::empty({batch_size, nsamples}, out_options);
18+
auto out_a = out.accessor<long, 2>();
19+
20+
for (ptrdiff_t b = 0; b < batch_size; b++)
21+
{
22+
auto y = points[b];
23+
ptrdiff_t start = 0;
24+
if (random)
25+
start = at::randperm(y.size(0), out_options).DATA_PTR<int64_t>()[0];
26+
27+
out_a[b][0] = start;
28+
auto dist = get_dist(y, start);
29+
for (ptrdiff_t i = 1; i < nsamples; i++)
30+
{
31+
ptrdiff_t argmax = dist.argmax().DATA_PTR<int64_t>()[0];
32+
out_a[b][i] = argmax;
33+
dist = at::min(dist, get_dist(y, argmax));
34+
}
35+
}
36+
return out;
37+
}

test/test_fps.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import unittest
2+
import torch
3+
import os
4+
import sys
5+
6+
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
7+
sys.path.insert(0, ROOT)
8+
9+
from torch_points.points_cpu import fps
10+
11+
12+
class TestFps(unittest.TestCase):
13+
def test_simplecpu(self):
14+
points = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]], [[-1, 1, 0], [0, 0, 10], [0, 0, 2]]]).float()
15+
idx = fps(points, 2, False)
16+
torch.testing.assert_allclose(idx, torch.tensor([[0, 2], [0, 1]]))
17+
18+
def test_random(self):
19+
points = torch.randn(10, 100, 3)
20+
idx = fps(points, 2, True)
21+
self.assertNotEqual(idx[0][0], 0)
22+
23+
24+
if __name__ == "__main__":
25+
unittest.main()

torch_points/torchpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def forward(ctx, xyz, npoint):
1717
if xyz.is_cuda:
1818
return tpcuda.furthest_point_sampling(xyz, npoint)
1919
else:
20-
raise NotImplementedError
20+
return tpcpu.fps(xyz, npoint, True)
2121

2222
@staticmethod
2323
def backward(xyz, a=None):

0 commit comments

Comments
 (0)