Skip to content

Commit ebac5f9

Browse files
Merge pull request #1 from nicolas-chaulet/group_points_cpp
CPU Group points forward function
2 parents 766cdd6 + 84b28cc commit ebac5f9

File tree

8 files changed

+164
-17
lines changed

8 files changed

+164
-17
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,8 @@
3232
*.app
3333

3434
build
35-
*.pyc
35+
*.pyc
36+
37+
.vscode/
38+
dist/
39+
torch_points.egg-info/

cpu/include/group_points.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once
2+
#include <torch/extension.h>
3+
4+
at::Tensor group_points(at::Tensor points, at::Tensor idx);
5+
at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);

cpu/include/utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#pragma once
2+
#include <torch/extension.h>
3+
4+
#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be a CPU tensor")
5+
6+
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be a contiguous tensor")

cpu/src/bindings.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#include "group_points.h"
2+
3+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
4+
m.def("group_points", &group_points);
5+
}

cpu/src/group_points.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#include "group_points.h"
2+
#include "utils.h"
3+
4+
// input: points(b, c, n) idx(b, npoints, nsample)
5+
// output: out(b, c, npoints, nsample)
6+
at::Tensor group_points(at::Tensor points, at::Tensor idx) {
7+
CHECK_CPU(points);
8+
CHECK_CPU(idx);
9+
10+
at::Tensor output = torch::zeros(
11+
{points.size(0), points.size(1), idx.size(1), idx.size(2)},
12+
at::device(points.device()).dtype(at::ScalarType::Float)
13+
);
14+
15+
for (int batch_index = 0; batch_index < output.size(0); batch_index++) {
16+
for (int feat_index = 0; feat_index < output.size(1); feat_index++) {
17+
for (int point_index = 0; point_index < output.size(2); point_index++) {
18+
for (int sample_index = 0; sample_index < output.size(3); sample_index++) {
19+
output[batch_index][feat_index][point_index][sample_index]
20+
= points[batch_index][feat_index][
21+
idx[batch_index][point_index][sample_index]
22+
];
23+
}
24+
}
25+
}
26+
}
27+
28+
return output;
29+
}

setup.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from setuptools import setup, find_packages
2-
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
2+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, CppExtension
33
import glob
44

55
ext_src_root = "cuda"
@@ -20,6 +20,19 @@
2020
)
2121
)
2222

23+
cpu_ext_src_root = "cpu"
24+
cpu_ext_sources = glob.glob("{}/src/*.cpp".format(cpu_ext_src_root))
25+
26+
ext_modules.append(
27+
CppExtension(
28+
name="torch_points.points_cpu",
29+
sources=cpu_ext_sources,
30+
extra_compile_args={
31+
"cxx": ["-O2", "-I{}".format("{}/include".format(cpu_ext_src_root))],
32+
},
33+
)
34+
)
35+
2336
setup(
2437
name="torch_points",
2538
version="0.1.2",

test/test_grouping.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import unittest
2+
import torch
3+
import numpy as np
4+
import numpy.testing as npt
5+
from torch_points import grouping_operation
6+
7+
class TestGroup(unittest.TestCase):
8+
9+
# input: points(b, c, n) idx(b, npoints, nsample)
10+
# output: out(b, c, npoints, nsample)
11+
def test_simple(self):
12+
features = torch.tensor([
13+
[[0, 10, 0], [1, 11, 0], [2, 12, 0]],
14+
[
15+
[100, 110, 120], # x-coordinates
16+
[101, 111, 121], # y-coordinates
17+
[102, 112, 122], # z-coordinates
18+
]
19+
])
20+
idx = torch.tensor([
21+
[[1, 0], [0, 0]],
22+
[[0, 1], [1, 2]]
23+
])
24+
25+
expected = np.array([
26+
[
27+
[[10, 0], [0, 0]],
28+
[[11, 1], [1, 1]],
29+
[[12, 2], [2, 2]]
30+
],
31+
[ # 2nd batch
32+
[ # x-coordinates
33+
[100, 110], #x-coordinates of samples for point 0
34+
[110, 120], #x-coordinates of samples for point 1
35+
],
36+
[[101, 111], [111, 121]], # y-coordinates
37+
[[102, 112], [112, 122]], # z-coordinates
38+
]
39+
])
40+
41+
cpu_output = grouping_operation(features, idx).detach().cpu().numpy()
42+
43+
npt.assert_array_equal(expected, cpu_output)
44+
45+
if torch.cuda.is_available():
46+
npt.assert_array_equal(
47+
grouping_operation(
48+
features.cuda(),
49+
idx.cuda()
50+
).detach().cpu().numpy(), expected)
51+
52+
53+
if __name__ == '__main__':
54+
unittest.main()

torch_points/torchpoints.py

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@
33
import torch.nn as nn
44
import sys
55

6-
import torch_points.points_cuda as tpcuda
6+
import torch_points.points_cpu as tpcpu
7+
8+
if torch.cuda.is_available():
9+
import torch_points.points_cuda as tpcuda
710

811

912
class FurthestPointSampling(Function):
1013
@staticmethod
1114
def forward(ctx, xyz, npoint):
12-
return tpcuda.furthest_point_sampling(xyz, npoint)
15+
if xyz.is_cuda:
16+
return tpcuda.furthest_point_sampling(xyz, npoint)
17+
else:
18+
raise NotImplementedError
1319

1420
@staticmethod
1521
def backward(xyz, a=None):
@@ -45,14 +51,20 @@ def forward(ctx, features, idx):
4551

4652
ctx.for_backwards = (idx, C, N)
4753

48-
return tpcuda.gather_points(features, idx)
54+
if features.is_cuda:
55+
return tpcuda.gather_points(features, idx)
56+
else:
57+
return tpcpu.gather_points(features, idx)
4958

5059
@staticmethod
5160
def backward(ctx, grad_out):
5261
idx, C, N = ctx.for_backwards
5362

54-
grad_features = tpcuda.gather_points_grad(grad_out.contiguous(), idx, N)
55-
return grad_features, None
63+
if grad_out.is_cuda:
64+
grad_features = tpcuda.gather_points_grad(grad_out.contiguous(), idx, N)
65+
return grad_features, None
66+
else:
67+
raise NotImplementedError
5668

5769

5870
def gather_operation(features, idx):
@@ -64,12 +76,12 @@ def gather_operation(features, idx):
6476
(B, C, N) tensor
6577
6678
idx : torch.Tensor
67-
(B, npoint) tensor of the features to gather
79+
(B, npoint, nsample) tensor of the features to gather
6880
6981
Returns
7082
-------
7183
torch.Tensor
72-
(B, C, npoint) tensor
84+
(B, C, npoint, nsample) tensor
7385
"""
7486
return GatherOperation.apply(features, idx)
7587

@@ -78,7 +90,11 @@ class ThreeNN(Function):
7890
@staticmethod
7991
def forward(ctx, unknown, known):
8092
# type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
81-
dist2, idx = tpcuda.three_nn(unknown, known)
93+
94+
if unknown.is_cuda:
95+
dist2, idx = tpcuda.three_nn(unknown, known)
96+
else:
97+
raise NotImplementedError
8298

8399
return torch.sqrt(dist2), idx
84100

@@ -116,7 +132,10 @@ def forward(ctx, features, idx, weight):
116132

117133
ctx.three_interpolate_for_backward = (idx, weight, m)
118134

119-
return tpcuda.three_interpolate(features, idx, weight)
135+
if features.is_cuda:
136+
return tpcuda.three_interpolate(features, idx, weight)
137+
else:
138+
raise NotImplementedError
120139

121140
@staticmethod
122141
def backward(ctx, grad_out):
@@ -138,9 +157,12 @@ def backward(ctx, grad_out):
138157
"""
139158
idx, weight, m = ctx.three_interpolate_for_backward
140159

141-
grad_features = tpcuda.three_interpolate_grad(
142-
grad_out.contiguous(), idx, weight, m
143-
)
160+
if grad_out.is_cuda:
161+
grad_features = tpcuda.three_interpolate_grad(
162+
grad_out.contiguous(), idx, weight, m
163+
)
164+
else:
165+
raise NotImplementedError
144166

145167
return grad_features, None, None
146168

@@ -174,7 +196,10 @@ def forward(ctx, features, idx):
174196

175197
ctx.for_backwards = (idx, N)
176198

177-
return tpcuda.group_points(features, idx)
199+
if features.is_cuda:
200+
return tpcuda.group_points(features, idx)
201+
else:
202+
return tpcpu.group_points(features, idx)
178203

179204
@staticmethod
180205
def backward(ctx, grad_out):
@@ -194,7 +219,10 @@ def backward(ctx, grad_out):
194219
"""
195220
idx, N = ctx.for_backwards
196221

197-
grad_features = tpcuda.group_points_grad(grad_out.contiguous(), idx, N)
222+
if grad_out.is_cuda:
223+
grad_features = tpcuda.group_points_grad(grad_out.contiguous(), idx, N)
224+
else:
225+
raise NotImplementedError
198226

199227
return grad_features, None
200228

@@ -220,7 +248,10 @@ class BallQuery(Function):
220248
@staticmethod
221249
def forward(ctx, radius, nsample, xyz, new_xyz):
222250
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
223-
return tpcuda.ball_query(new_xyz, xyz, radius, nsample)
251+
if new_xyz.is_cuda:
252+
return tpcuda.ball_query(new_xyz, xyz, radius, nsample)
253+
else:
254+
raise NotImplementedError
224255

225256
@staticmethod
226257
def backward(ctx, a=None):

0 commit comments

Comments
 (0)