Skip to content

Commit fdc285a

Browse files
Merge pull request #9 from nicolas-chaulet/ball_query_inverse
'reverse search'
2 parents 01407c7 + 01e112d commit fdc285a

File tree

4 files changed

+19
-19
lines changed

4 files changed

+19
-19
lines changed

cuda/src/ball_query.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ at::Tensor ball_query_dense(at::Tensor new_xyz, at::Tensor xyz, const float radi
2828
CHECK_CUDA(xyz);
2929
}
3030

31-
at::Tensor idx =
32-
torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
31+
at::Tensor idx = torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
3332
at::device(new_xyz.device()).dtype(at::ScalarType::Int));
3433

3534
if (new_xyz.type().is_cuda()) {
@@ -67,11 +66,11 @@ std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x,
6766
CHECK_CUDA(batch_y);
6867
}
6968

70-
at::Tensor idx = torch::full({x.size(0), nsample}, y.size(0),
71-
at::device(x.device()).dtype(at::ScalarType::Long));
69+
at::Tensor idx = torch::full({y.size(0), nsample}, x.size(0),
70+
at::device(y.device()).dtype(at::ScalarType::Long));
7271

73-
at::Tensor dist = torch::full({x.size(0), nsample}, -1,
74-
at::device(x.device()).dtype(at::ScalarType::Float));
72+
at::Tensor dist = torch::full({y.size(0), nsample}, -1,
73+
at::device(y.device()).dtype(at::ScalarType::Float));
7574

7675
cudaSetDevice(x.get_device());
7776
auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));

cuda/src/ball_query_gpu.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ __global__ void query_ball_point_kernel_partial_dense(int size_x,
7575
(x[n_x * 3 + d] - y[n_y * 3 + d]);
7676
}
7777
if(dist <= radius2){
78-
idx_out[n_x * nsample + count] = n_y;
79-
dist_out[n_x * nsample + count] = dist;
78+
idx_out[n_y * nsample + count] = n_x;
79+
dist_out[n_y * nsample + count] = dist;
8080
count++;
8181
}
8282
if(count >= nsample){

test/test_ballquerry_partial.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def test_simple_gpu(self):
2020
idx = idx.detach().cpu().numpy()
2121
dist2 = dist2.detach().cpu().numpy()
2222

23-
idx_answer = np.asarray([[1, 1], [0, 1], [1, 1], [1, 1]])
24-
dist2_answer = np.asarray([[-1, -1], [0.01, -1], [-1, -1], [-1, -1]]).astype(np.float32)
23+
idx_answer = np.asarray([[1, 4]])
24+
dist2_answer = np.asarray([[ 0.0100, -1.0000]]).astype(np.float32)
2525

2626
npt.assert_array_almost_equal(idx, idx_answer)
2727
npt.assert_array_almost_equal(dist2, dist2_answer)

torch_points/torchpoints.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import torch_points.points_cuda as tpcuda
1010

1111

12-
1312
class FurthestPointSampling(Function):
1413
@staticmethod
1514
def forward(ctx, xyz, npoint):
@@ -289,15 +288,16 @@ def ball_query_dense(radius, nsample, xyz, new_xyz):
289288
"""
290289
return BallQueryDense.apply(radius, nsample, xyz, new_xyz)
291290

291+
292292
class BallQueryPartialDense(Function):
293293
@staticmethod
294294
def forward(ctx, radius, nsample, x, y, batch_x, batch_y):
295295
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
296296
if x.is_cuda:
297-
return tpcuda.ball_query_partial_dense(x, y,
298-
batch_x,
299-
batch_y,
300-
radius, nsample)
297+
return tpcuda.ball_query_partial_dense(x, y,
298+
batch_x,
299+
batch_y,
300+
radius, nsample)
301301
else:
302302
raise NotImplementedError
303303

@@ -315,7 +315,7 @@ def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y):
315315
nsample : int
316316
maximum number of features in the balls
317317
x : torch.Tensor
318-
(M, 3) xyz coordinates of the features
318+
(M, 3) xyz coordinates of the features (The neighbours are going to be looked for there)
319319
y : torch.Tensor
320320
(N, npoint, 3) centers of the ball query
321321
batch_x : torch.Tensor
@@ -326,11 +326,12 @@ def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y):
326326
Returns
327327
-------
328328
torch.Tensor
329-
idx: (M, nsample) Default value: N. It contains the indexes of the element within y at radius distance to x
330-
dist2: (M, nsample) Default value: -1. It contains the square distances of the element within y at radius distance to x
329+
idx: (N, nsample) Default value: N. It contains the indexes of the element within y at radius distance to x
330+
dist2: (N, nsample) Default value: -1. It contains the square distances of the element within y at radius distance to x
331331
"""
332332
return BallQueryPartialDense.apply(radius, nsample, x, y, batch_x, batch_y)
333333

334+
334335
def ball_query(radius: float, nsample: int, x, y, batch_x=None, batch_y=None, mode=None):
335336
if mode is None:
336337
raise Exception('The mode should be defined within ["PARTIAL_DENSE | DENSE"]')
@@ -347,4 +348,4 @@ def ball_query(radius: float, nsample: int, x, y, batch_x=None, batch_y=None, mo
347348
raise Exception('batch_x and batch_y should not be provided')
348349
return ball_query_dense(radius, nsample, x, y)
349350
else:
350-
raise Exception('unrecognized mode {}'.format(mode))
351+
raise Exception('unrecognized mode {}'.format(mode))

0 commit comments

Comments
 (0)