Skip to content

Commit a5d6145

Browse files
Remove negative indexing
1 parent ee82d3d commit a5d6145

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

cpu/src/ball_query.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,13 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
8484
auto options_dist = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
8585

8686
int max_count = 0;
87-
auto batch_access = query_batch.accessor<int64_t, 1>();
87+
auto q_batch_access = query_batch.accessor<int64_t, 1>();
88+
auto s_batch_access = support_batch.accessor<int64_t, 1>();
89+
90+
auto batch_size = q_batch_access[query_batch.size(0) - 1] + 1;
91+
TORCH_CHECK(batch_size == (s_batch_access[support_batch.size(0) - 1] + 1),
92+
"Both batches need to have the same number of samples.")
8893

89-
auto batch_size = batch_access[-1] + 1;
9094
query_batch = degree(query_batch, batch_size);
9195
query_batch = at::cat({at::zeros(1, query_batch.options()), query_batch.cumsum(0)}, 0);
9296
support_batch = degree(support_batch, batch_size);

test/test_ballquerry.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
import unittest
22
import torch
3-
from torch_points_kernels import ball_query
43
import numpy.testing as npt
54
import numpy as np
65
from sklearn.neighbors import KDTree
6+
import os
7+
import sys
8+
9+
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
10+
sys.path.insert(0, ROOT)
711

8-
from . import run_if_cuda
12+
from test import run_if_cuda
13+
from torch_points_kernels import ball_query
914

1015

1116
class TestBall(unittest.TestCase):
@@ -76,10 +81,10 @@ def test_simple_gpu(self):
7681
npt.assert_array_almost_equal(dist2, dist2_answer)
7782

7883
def test_simple_cpu(self):
79-
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]]).to(torch.float)
84+
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [10.1, 0, 0]]).to(torch.float)
8085
y = torch.tensor([[0, 0, 0]]).to(torch.float)
8186

82-
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long()
87+
batch_x = torch.from_numpy(np.asarray([0, 0, 0, 0])).long()
8388
batch_y = torch.from_numpy(np.asarray([0])).long()
8489

8590
idx, dist2 = ball_query(1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y)
@@ -93,6 +98,17 @@ def test_simple_cpu(self):
9398
npt.assert_array_almost_equal(idx, idx_answer)
9499
npt.assert_array_almost_equal(dist2, dist2_answer)
95100

101+
102+
def test_breaks(self):
103+
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [10.1, 0, 0]]).to(torch.float)
104+
y = torch.tensor([[0, 0, 0]]).to(torch.float)
105+
106+
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long()
107+
batch_y = torch.from_numpy(np.asarray([0])).long()
108+
109+
with self.assertRaises(RuntimeError):
110+
idx, dist2 = ball_query(1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y)
111+
96112
def test_random_cpu(self):
97113
a = torch.randn(100, 3).to(torch.float)
98114
b = torch.randn(50, 3).to(torch.float)

0 commit comments

Comments
 (0)