Skip to content

Commit 2e02691

Browse files
committed
'still a bug'
1 parent bb09dc8 commit 2e02691

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

cuda/src/ball_query_gpu.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,17 @@ __global__ void query_ball_point_kernel_partial_dense(int size_x,
6868
const ptrdiff_t end_idx_y = batch_y[batch_idx + 1];
6969
float radius2 = radius * radius;
7070

71-
for (ptrdiff_t n_x = start_idx_x + idx; n_x < end_idx_x + 1; n_x += THREADS) {
71+
for (ptrdiff_t n_x = start_idx_x + idx; n_x < end_idx_x; n_x += THREADS) {
72+
printf("n_x: %d \n", n_x);
73+
7274
int64_t count = 0;
73-
for (ptrdiff_t n_y = start_idx_y; n_y < end_idx_y + 1; n_y++) {
75+
for (ptrdiff_t n_y = start_idx_y; n_y < end_idx_y; n_y++) {
7476
float dist = 0;
7577
for (ptrdiff_t d = 0; d < 3; d++) {
7678
dist += (x[n_x * 3 + d] - y[n_y * 3 + d]) *
7779
(x[n_x * 3 + d] - y[n_y * 3 + d]);
7880
}
79-
printf("Hello from (%d, %d) block, %d, thread %d\n", n_x, n_y, blockIdx.x, threadIdx.x);
81+
printf("n_x: %d, n_y: %d \n", n_x, n_y);
8082
if(dist <= radius2){
8183
idx_out[n_x * nsample + count] = n_y;
8284
dist_out[n_x * nsample + count] = dist;

test/test_ballquerry_partial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_simple_gpu(self):
1818
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
1919
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()
2020

21-
idx, dist2 = ball_query(1, 2, x, y, batch_x, batch_y, mode="PARTIAL_DENSE")
21+
idx, dist2 = ball_query(1., 2, x, y, batch_x, batch_y, mode="PARTIAL_DENSE")
2222

2323
idx = idx.detach().cpu().numpy()
2424
dist2 = dist2.detach().cpu().numpy()

torch_points/torchpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y):
330330
"""
331331
return BallQueryPartialDense.apply(radius, nsample, x, y, batch_x, batch_y)
332332

333-
def ball_query(radius, nsample, x, y, batch_x=None, batch_y=None, mode=None):
333+
def ball_query(radius: float, nsample: int, x, y, batch_x=None, batch_y=None, mode=None):
334334
if mode is None:
335335
raise Exception('The mode should be defined within ["PARTIAL_DENSE | DENSE"]')
336336

0 commit comments

Comments
 (0)