Skip to content

Commit 9c1355e

Browse files
Merge pull request #10 from nicolas-chaulet/singleapi
Singleapi
2 parents 2884bcc + 3c9a8e4 commit 9c1355e

File tree

6 files changed

+135
-149
lines changed

6 files changed

+135
-149
lines changed

cuda/include/cuda_utils.h

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,18 +10,19 @@
1010

1111
#include <vector>
1212

13-
#define TOTAL_THREADS 1024
13+
#define TOTAL_THREADS_DENSE 512
14+
#define TOTAL_THREADS_SPARSE 1024
1415

1516
inline int opt_n_threads(int work_size) {
1617
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
1718

18-
return max(min(1 << pow_2, TOTAL_THREADS), 1);
19+
return max(min(1 << pow_2, TOTAL_THREADS_DENSE), 1);
1920
}
2021

2122
inline dim3 opt_block_config(int x, int y) {
2223
const int x_threads = opt_n_threads(x);
2324
const int y_threads =
24-
max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
25+
max(min(opt_n_threads(y), TOTAL_THREADS_DENSE / x_threads), 1);
2526
dim3 block_config(x_threads, y_threads, 1);
2627

2728
return block_config;

cuda/src/ball_query_gpu.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ __global__ void query_ball_point_kernel_partial_dense(int size_x,
6666
const ptrdiff_t end_idx_y = batch_y[batch_idx + 1];
6767
float radius2 = radius * radius;
6868

69-
for (ptrdiff_t n_x = start_idx_x + idx; n_x < end_idx_x; n_x += TOTAL_THREADS) {
69+
for (ptrdiff_t n_x = start_idx_x + idx; n_x < end_idx_x; n_x += TOTAL_THREADS_SPARSE) {
7070
int64_t count = 0;
7171
for (ptrdiff_t n_y = start_idx_y; n_y < end_idx_y; n_y++) {
7272
float dist = 0;
@@ -108,7 +108,7 @@ void query_ball_point_kernel_partial_wrapper(long batch_size,
108108
int64_t *idx_out,
109109
float *dist_out) {
110110

111-
query_ball_point_kernel_partial_dense<<<batch_size, TOTAL_THREADS>>>(
111+
query_ball_point_kernel_partial_dense<<<batch_size, TOTAL_THREADS_SPARSE>>>(
112112
size_x, size_y, radius, nsample, x, y,
113113
batch_x, batch_y, idx_out, dist_out);
114114

setup.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,10 @@
11
from setuptools import setup, find_packages
2-
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, CppExtension
2+
from torch.utils.cpp_extension import (
3+
BuildExtension,
4+
CUDAExtension,
5+
CUDA_HOME,
6+
CppExtension,
7+
)
38
import glob
49

510
ext_src_root = "cuda"
@@ -33,12 +38,14 @@
3338
)
3439
)
3540

41+
# Runtime dependencies handed to setuptools via install_requires.
# PEP 440 defines no caret ("^") operator, so a caret specifier is rejected as an
# invalid requirement; the minimum supported torch version is expressed with ">=".
requirements = ["torch>=1.1.0"]
42+
3643
setup(
3744
name="torch_points",
38-
version="0.1.3",
45+
version="0.1.5",
3946
author="Nicolas Chaulet",
4047
packages=find_packages(),
41-
install_requires=[],
48+
install_requires=requirements,
4249
ext_modules=ext_modules,
4350
cmdclass={"build_ext": BuildExtension},
4451
)

test/test_ballquerry.py

Lines changed: 74 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import unittest
22
import torch
3-
from torch_points import ball_query_dense
3+
from torch_points import ball_query
44
import numpy.testing as npt
55
import numpy as np
66

@@ -10,22 +10,88 @@ def test_simple_gpu(self):
1010
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float).cuda()
1111
b = torch.tensor([[[0, 0, 0]]]).to(torch.float).cuda()
1212

13-
npt.assert_array_equal(ball_query_dense(1, 2, a, b).detach().cpu().numpy(), np.array([[[0, 0]]]))
13+
npt.assert_array_equal(
14+
ball_query(1, 2, a, b).detach().cpu().numpy(), np.array([[[0, 0]]])
15+
)
1416

15-
def test_simple_cpu(self):
16-
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float)
17-
b = torch.tensor([[[0, 0, 0]]]).to(torch.float)
18-
npt.assert_array_equal(ball_query_dense(1, 2, a, b).detach().numpy(), np.array([[[0, 0]]]))
17+
def test_larger_gpu(self):
18+
a = torch.randn(32, 4096, 3).to(torch.float).cuda()
19+
idx = ball_query(1, 64, a, a).detach().cpu().numpy()
20+
self.assertGreaterEqual(idx.min(), 0)
1921

2022
def test_cpu_gpu_equality(self):
2123
a = torch.randn(5, 1000, 3)
22-
res_cpu = ball_query_dense(0.1, 17, a, a).detach().numpy()
23-
res_cuda = ball_query_dense(0.1, 17, a.cuda(), a.cuda()).cpu().detach().numpy()
24+
res_cpu = ball_query(0.1, 17, a, a).detach().numpy()
25+
res_cuda = ball_query(0.1, 17, a.cuda(), a.cuda()).cpu().detach().numpy()
2426
for i in range(a.shape[0]):
2527
for j in range(a.shape[1]):
2628
# Because the order is not necessarily the same
2729
assert set(res_cpu[i][j]) == set(res_cuda[i][j])
2830

2931

32+
class TestBallPartial(unittest.TestCase):
33+
def test_simple_gpu(self):
34+
x = (
35+
torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]])
36+
.to(torch.float)
37+
.cuda()
38+
)
39+
y = torch.tensor([[0, 0, 0]]).to(torch.float).cuda()
40+
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
41+
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()
42+
43+
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
44+
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()
45+
46+
idx, dist2 = ball_query(
47+
1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y
48+
)
49+
50+
idx = idx.detach().cpu().numpy()
51+
dist2 = dist2.detach().cpu().numpy()
52+
53+
idx_answer = np.asarray([[1, 4]])
54+
dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32)
55+
56+
npt.assert_array_almost_equal(idx, idx_answer)
57+
npt.assert_array_almost_equal(dist2, dist2_answer)
58+
59+
def test_simple_cpu(self):
60+
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]]).to(
61+
torch.float
62+
)
63+
y = torch.tensor([[0, 0, 0]]).to(torch.float)
64+
65+
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long()
66+
batch_y = torch.from_numpy(np.asarray([0])).long()
67+
68+
idx, dist2 = ball_query(
69+
1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y
70+
)
71+
72+
idx = idx.detach().cpu().numpy()
73+
dist2 = dist2.detach().cpu().numpy()
74+
75+
idx_answer = np.asarray([[1, 1], [0, 1], [1, 1], [1, 1]])
76+
dist2_answer = np.asarray([[-1, -1], [0.01, -1], [-1, -1], [-1, -1]]).astype(
77+
np.float32
78+
)
79+
80+
npt.assert_array_almost_equal(idx, idx_answer)
81+
npt.assert_array_almost_equal(dist2, dist2_answer)
82+
83+
def test_random_cpu(self):
84+
a = torch.randn(1000, 3).to(torch.float)
85+
b = torch.randn(1500, 3).to(torch.float)
86+
batch_a = torch.randint(1, (1000,)).sort(0)[0].long()
87+
batch_b = torch.randint(1, (1500,)).sort(0)[0].long()
88+
idx, dist = ball_query(
89+
1.0, 12, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b
90+
)
91+
idx2, dist2 = ball_query(
92+
1.0, 12, b, a, mode="PARTIAL_DENSE", batch_x=batch_b, batch_y=batch_a
93+
)
94+
95+
3096
if __name__ == "__main__":
3197
unittest.main()

test/test_ballquerry_partial.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

torch_points/torchpoints.py

Lines changed: 45 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from torch.autograd import Function
33
import torch.nn as nn
44
import sys
5+
from typing import Optional
56

67
import torch_points.points_cpu as tpcpu
78

@@ -251,98 +252,78 @@ def forward(ctx, radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=No
251252
if new_xyz.is_cuda:
252253
return tpcuda.ball_query_dense(new_xyz, xyz, radius, nsample)
253254
else:
254-
ind, dist = tpcpu.dense_ball_query(new_xyz,
255-
xyz,
256-
radius, nsample, mode=0)
255+
ind, dist = tpcpu.dense_ball_query(new_xyz, xyz, radius, nsample, mode=0)
257256
return ind
258257

259258
@staticmethod
260259
def backward(ctx, a=None):
261260
return None, None, None, None
262261

263262

264-
def ball_query_dense(radius, nsample, xyz, new_xyz):
265-
r"""
266-
Parameters
267-
----------
268-
radius : float
269-
radius of the balls
270-
nsample : int
271-
maximum number of features in the balls
272-
xyz : torch.Tensor
273-
(B, N, 3) xyz coordinates of the features
274-
new_xyz : torch.Tensor
275-
(B, npoint, 3) centers of the ball query
276-
277-
Returns
278-
-------
279-
torch.Tensor
280-
(B, npoint, nsample) tensor with the indicies of the features that form the query balls
281-
"""
282-
return BallQueryDense.apply(radius, nsample, xyz, new_xyz)
283-
284-
285263
class BallQueryPartialDense(Function):
286264
@staticmethod
287265
def forward(ctx, radius, nsample, x, y, batch_x, batch_y):
288266
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
289267
if x.is_cuda:
290-
return tpcuda.ball_query_partial_dense(x, y,
291-
batch_x,
292-
batch_y,
293-
radius, nsample)
268+
return tpcuda.ball_query_partial_dense(
269+
x, y, batch_x, batch_y, radius, nsample
270+
)
294271
else:
295-
ind, dist = tpcpu.batch_ball_query(x, y,
296-
batch_x,
297-
batch_y,
298-
radius, nsample, mode=0)
272+
ind, dist = tpcpu.batch_ball_query(
273+
x, y, batch_x, batch_y, radius, nsample, mode=0
274+
)
299275
return ind, dist
300276

301277
@staticmethod
302278
def backward(ctx, a=None):
303279
return None, None, None, None
304280

305281

306-
def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y):
307-
r"""
308-
Parameters
309-
----------
310-
radius : float
311-
radius of the balls
312-
nsample : int
313-
maximum number of features in the balls
314-
x : torch.Tensor
315-
(M, 3) xyz coordinates of the features (The neighbours are going to be looked for there)
316-
y : torch.Tensor
317-
(N, npoint, 3) centers of the ball query
318-
batch_x : torch.Tensor
319-
(M, ) Contains indexes to indicate within batch it belongs to.
320-
batch_y : torch.Tensor
321-
(N, ) Contains indexes to indicate within batch it belongs to
322-
323-
Returns
324-
-------
325-
torch.Tensor
326-
idx: (N, nsample) Default value: N. It contains the indexes of the element within y at radius distance to x
327-
dist2: (N, nsample) Default value: -1. It contains the square distances of the element within y at radius distance to x
282+
def ball_query(
283+
radius: float,
284+
nsample: int,
285+
x: torch.Tensor,
286+
y: torch.Tensor,
287+
mode: Optional[str] = "dense",
288+
batch_x: Optional[torch.tensor] = None,
289+
batch_y: Optional[torch.tensor] = None,
290+
) -> torch.Tensor:
291+
"""
292+
Arguments:
293+
radius {float} -- radius of the balls
294+
nsample {int} -- maximum number of features in the balls
295+
x {torch.Tensor} --
296+
(M, 3) [partial_dense] or (B, M, 3) [dense] xyz coordinates of the features
297+
y {torch.Tensor} --
298+
(npoint, 3) [partial_dense] or (B, npoint, 3) [dense] centers of the ball query
299+
mode {str} -- switch between "dense" or "partial_dense" data layout
300+
301+
Keyword Arguments:
302+
batch_x -- (M, ) [partial_dense] or (B, M, 3) [dense] Contains indexes to indicate within batch it belongs to.
303+
batch_y -- (N, ) Contains indexes to indicate within batch it belongs to
304+
305+
306+
Returns:
307+
idx: (npoint, nsample) or (B, npoint, nsample) [dense] It contains the indexes of the element within x at radius distance to y
308+
OPTIONAL[partial_dense] dist2: (N, nsample) Default value: -1.
309+
It contains the square distances of the element within x at radius distance to y
328310
"""
329-
return BallQueryPartialDense.apply(radius, nsample, x, y, batch_x, batch_y)
330-
331-
332-
def ball_query(radius: float, nsample: int, x, y, batch_x=None, batch_y=None, mode=None):
333311
if mode is None:
334-
raise Exception('The mode should be defined within ["PARTIAL_DENSE | DENSE"]')
312+
raise Exception('The mode should be defined within ["partial_dense | dense"]')
335313

336314
if mode.lower() == "partial_dense":
337315
if (batch_x is None) or (batch_y is None):
338-
raise Exception('batch_x and batch_y should be provided')
316+
raise Exception("batch_x and batch_y should be provided")
339317
assert x.size(0) == batch_x.size(0)
340318
assert y.size(0) == batch_y.size(0)
341-
return ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y)
319+
assert x.dim() == 2
320+
return BallQueryPartialDense.apply(radius, nsample, x, y, batch_x, batch_y)
342321

343322
elif mode.lower() == "dense":
344323
if (batch_x is not None) or (batch_y is not None):
345-
raise Exception('batch_x and batch_y should not be provided')
346-
return ball_query_dense(radius, nsample, x, y)
324+
raise Exception("batch_x and batch_y should not be provided")
325+
assert x.dim() == 3
326+
return BallQueryDense.apply(radius, nsample, x, y)
347327
else:
348-
raise Exception('unrecognized mode {}'.format(mode))
328+
raise Exception("unrecognized mode {}".format(mode))
329+

0 commit comments

Comments
 (0)