Skip to content

Commit f47eabb

Browse files
Ramp up the tests
1 parent 03e61c4 commit f47eabb

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

cpu/src/neighbors.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,6 @@ int nanoflann_neighbors(vector<scalar_t>& queries,
119119
}
120120
i0++;
121121
}
122-
123-
124122
}
125123
return max_count;
126124
}

test/test_ballquerry.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,17 @@ def test_larger_gpu(self):
3030

3131
@run_if_cuda
3232
def test_cpu_gpu_equality(self):
33-
a = torch.randn(2, 10, 3)
34-
b = torch.randn(2, 5, 3)
35-
res_cpu = ball_query(1, 17, a, b).detach().numpy()
36-
res_cuda = ball_query(1, 17, a.cuda(), b.cuda()).cpu().detach().numpy()
33+
a = torch.randn(5, 1000, 3)
34+
b = torch.randn(5, 500, 3)
35+
res_cpu = ball_query(1, 500, a, b).detach().numpy()
36+
res_cuda = ball_query(1, 500, a.cuda(), b.cuda()).cpu().detach().numpy()
37+
for i in range(b.shape[0]):
38+
for j in range(b.shape[1]):
39+
# Because it is not necessary the same order
40+
assert set(res_cpu[i][j]) == set(res_cuda[i][j])
41+
42+
res_cpu = ball_query(0.01, 500, a, b).detach().numpy()
43+
res_cuda = ball_query(0.01, 500, a.cuda(), b.cuda()).cpu().detach().numpy()
3744
for i in range(b.shape[0]):
3845
for j in range(b.shape[1]):
3946
# Because it is not necessary the same order

0 commit comments

Comments
 (0)