Skip to content

Commit 022b909

Browse files
Single API for ball querry
1 parent fdc285a commit 022b909

File tree

4 files changed

+59
-89
lines changed

4 files changed

+59
-89
lines changed

cuda/include/cuda_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
#include <vector>
1212

13-
#define TOTAL_THREADS 1024
13+
#define TOTAL_THREADS 512
1414

1515
inline int opt_n_threads(int work_size) {
1616
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);

test/test_ballquerry.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
import torch
3-
from torch_points import ball_query_dense
3+
from torch_points import ball_query
44
import numpy.testing as npt
55
import numpy as np
66

@@ -10,17 +10,42 @@ def test_simple_gpu(self):
1010
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float).cuda()
1111
b = torch.tensor([[[0, 0, 0]]]).to(torch.float).cuda()
1212

13-
npt.assert_array_equal(ball_query_dense(1, 2, a, b).detach().cpu().numpy(), np.array([[[0, 0]]]))
13+
npt.assert_array_equal(ball_query(1, 2, a, b).detach().cpu().numpy(), np.array([[[0, 0]]]))
1414

15-
def test_simple_cpu(self):
16-
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float)
17-
b = torch.tensor([[[0, 0, 0]]]).to(torch.float)
18-
npt.assert_array_equal(ball_query_dense(1, 2, a, b).detach().numpy(), np.array([[[0, 0]]]))
15+
def test_larger_gpu(self):
16+
a = torch.randn(32, 4096, 3).to(torch.float).cuda()
17+
idx = ball_query(1, 64, a, a).detach().cpu().numpy()
18+
self.assertGreaterEqual(idx.min(),0)
1919

20-
def test_cpu_gpu_equality(self):
21-
a = torch.randn(5, 1000, 3)
22-
npt.assert_array_equal(ball_query_dense(0.1, 17, a, a).detach().numpy(),
23-
ball_query_dense(0.1, 17, a.cuda(), a.cuda()).cpu().detach().numpy())
20+
# def test_simple_cpu(self):
21+
# a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float)
22+
# b = torch.tensor([[[0, 0, 0]]]).to(torch.float)
23+
# npt.assert_array_equal(ball_query(1, 2, a, b).detach().numpy(), np.array([[[0, 0]]]))
24+
25+
# def test_cpu_gpu_equality(self):
26+
# a = torch.randn(5, 1000, 3)
27+
# npt.assert_array_equal(ball_query(0.1, 17, a, a).detach().numpy(),
28+
# ball_query(0.1, 17, a.cuda(), a.cuda()).cpu().detach().numpy())
29+
30+
def test_partial_gpu(self):
31+
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]]).to(torch.float).cuda()
32+
y = torch.tensor([[0, 0, 0]]).to(torch.float).cuda()
33+
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
34+
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()
35+
36+
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
37+
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()
38+
39+
idx, dist2 = ball_query(1., 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y)
40+
41+
idx = idx.detach().cpu().numpy()
42+
dist2 = dist2.detach().cpu().numpy()
43+
44+
idx_answer = np.asarray([[1, 4]])
45+
dist2_answer = np.asarray([[ 0.0100, -1.0000]]).astype(np.float32)
46+
47+
npt.assert_array_almost_equal(idx, idx_answer)
48+
npt.assert_array_almost_equal(dist2, dist2_answer)
2449

2550

2651
if __name__ == "__main__":

test/test_ballquerry_partial.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

torch_points/torchpoints.py

Lines changed: 23 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from torch.autograd import Function
33
import torch.nn as nn
44
import sys
5+
from typing import Optional
56

67
import torch_points.points_cpu as tpcpu
78

@@ -268,27 +269,6 @@ def backward(ctx, a=None):
268269
return None, None, None, None
269270

270271

271-
def ball_query_dense(radius, nsample, xyz, new_xyz):
272-
r"""
273-
Parameters
274-
----------
275-
radius : float
276-
radius of the balls
277-
nsample : int
278-
maximum number of features in the balls
279-
xyz : torch.Tensor
280-
(B, N, 3) xyz coordinates of the features
281-
new_xyz : torch.Tensor
282-
(B, npoint, 3) centers of the ball query
283-
284-
Returns
285-
-------
286-
torch.Tensor
287-
(B, npoint, nsample) tensor with the indicies of the features that form the query balls
288-
"""
289-
return BallQueryDense.apply(radius, nsample, xyz, new_xyz)
290-
291-
292272
class BallQueryPartialDense(Function):
293273
@staticmethod
294274
def forward(ctx, radius, nsample, x, y, batch_x, batch_y):
@@ -306,46 +286,41 @@ def backward(ctx, a=None):
306286
return None, None, None, None
307287

308288

309-
def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y):
310-
r"""
311-
Parameters
312-
----------
313-
radius : float
314-
radius of the balls
315-
nsample : int
316-
maximum number of features in the balls
317-
x : torch.Tensor
318-
(M, 3) xyz coordinates of the features (The neighbours are going to be looked for there)
319-
y : torch.Tensor
320-
(N, npoint, 3) centers of the ball query
321-
batch_x : torch.Tensor
322-
(M, ) Contains indexes to indicate within batch it belongs to.
323-
batch_y : torch.Tensor
324-
(N, ) Contains indexes to indicate within batch it belongs to
289+
def ball_query(radius: float, nsample: int, x: torch.Tensor, y: torch.Tensor, mode: Optional[str] = 'dense',
290+
batch_x: Optional[torch.tensor] = None, batch_y: Optional[torch.tensor] = None) -> torch.Tensor:
291+
"""
292+
Arguments:
293+
radius {float} -- radius of the balls
294+
nsample {int} -- maximum number of features in the balls
295+
x {torch.Tensor} --
296+
(M, 3) [partial_dense] or (B, M, 3) [dense] xyz coordinates of the features
297+
y {torch.Tensor} --
298+
(npoint, 3) [partial_dense] or or (B, npoint, 3) [dense] centers of the ball query
299+
mode {str} -- switch between "dense" or "partial_dense" data layout
325300
326-
Returns
327-
-------
328-
torch.Tensor
329-
idx: (N, nsample) Default value: N. It contains the indexes of the element within y at radius distance to x
330-
dist2: (N, nsample) Default value: -1. It contains the square distances of the element within y at radius distance to x
331-
"""
332-
return BallQueryPartialDense.apply(radius, nsample, x, y, batch_x, batch_y)
301+
Keyword Arguments:
302+
batch_x {Optional[torch.tensor]} -- (M, ) Contains indexes to indicate within batch it belongs to.
303+
batch_y {Optional[torch.tensor]} -- (N, ) Contains indexes to indicate within batch it belongs to
333304
334305
335-
def ball_query(radius: float, nsample: int, x, y, batch_x=None, batch_y=None, mode=None):
306+
Returns:
307+
[type] -- [description]
308+
"""
336309
if mode is None:
337-
raise Exception('The mode should be defined within ["PARTIAL_DENSE | DENSE"]')
310+
raise Exception('The mode should be defined within ["partial_dense | dense"]')
338311

339312
if mode.lower() == "partial_dense":
340313
if (batch_x is None) or (batch_y is None):
341314
raise Exception('batch_x and batch_y should be provided')
342315
assert x.size(0) == batch_x.size(0)
343316
assert y.size(0) == batch_y.size(0)
344-
return ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y)
317+
assert x.dim() == 2
318+
return BallQueryPartialDense.apply(radius, nsample, x, y, batch_x, batch_y)
345319

346320
elif mode.lower() == "dense":
347321
if (batch_x is not None) or (batch_y is not None):
348322
raise Exception('batch_x and batch_y should not be provided')
349-
return ball_query_dense(radius, nsample, x, y)
323+
assert x.dim() == 3
324+
return BallQueryDense.apply(radius, nsample, x, y)
350325
else:
351326
raise Exception('unrecognized mode {}'.format(mode))

0 commit comments

Comments
 (0)