Skip to content

Commit de90ce4

Browse files
committed
'ball_query_2 is working'
1 parent 2e02691 commit de90ce4

File tree

3 files changed

+5
-13
lines changed

3 files changed

+5
-13
lines changed

cuda/src/ball_query.cpp

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,7 @@ std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x,
6969

7070
at::Tensor idx = torch::full({x.size(0), nsample}, y.size(0),
7171
at::device(x.device()).dtype(at::ScalarType::Long));
72+
7273
at::Tensor dist = torch::full({x.size(0), nsample}, -1,
7374
at::device(x.device()).dtype(at::ScalarType::Float));
7475

@@ -78,15 +79,11 @@ std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x,
7879
cudaMemcpyDeviceToHost);
7980
auto batch_size = batch_sizes[0] + 1;
8081

81-
std::cout << batch_x << std::endl;
8282
batch_x = degree(batch_x, batch_size);
8383
batch_x = at::cat({at::zeros(1, batch_x.options()), batch_x.cumsum(0)}, 0);
8484
batch_y = degree(batch_y, batch_size);
8585
batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);
8686

87-
std::cout << batch_x << std::endl;
88-
std::cout << batch_y << std::endl;
89-
9087
if (x.type().is_cuda()) {
9188
query_ball_point_kernel_partial_wrapper(batch_size,
9289
x.size(0),

cuda/src/ball_query_gpu.cu

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -69,16 +69,13 @@ __global__ void query_ball_point_kernel_partial_dense(int size_x,
6969
float radius2 = radius * radius;
7070

7171
for (ptrdiff_t n_x = start_idx_x + idx; n_x < end_idx_x; n_x += THREADS) {
72-
printf("n_x: %d \n", n_x);
73-
7472
int64_t count = 0;
7573
for (ptrdiff_t n_y = start_idx_y; n_y < end_idx_y; n_y++) {
7674
float dist = 0;
7775
for (ptrdiff_t d = 0; d < 3; d++) {
7876
dist += (x[n_x * 3 + d] - y[n_y * 3 + d]) *
7977
(x[n_x * 3 + d] - y[n_y * 3 + d]);
8078
}
81-
printf("n_x: %d, n_y: %d \n", n_x, n_y);
8279
if(dist <= radius2){
8380
idx_out[n_x * nsample + count] = n_y;
8481
dist_out[n_x * nsample + count] = dist;

test/test_ballquerry_partial.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,6 @@ def test_simple_gpu(self):
1313
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
1414
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()
1515

16-
print(x.shape, y.shape, batch_x.shape, batch_y.shape)
17-
1816
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
1917
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()
2018

@@ -23,11 +21,11 @@ def test_simple_gpu(self):
2321
idx = idx.detach().cpu().numpy()
2422
dist2 = dist2.detach().cpu().numpy()
2523

26-
print(idx)
27-
print(dist2)
24+
idx_answer = np.asarray([[1, 1], [0, 1], [1, 1], [1, 1]])
25+
dist2_answer = np.asarray([[-1, -1], [0.01, -1], [-1, -1], [-1, -1]]).astype(np.float32)
2826

29-
#npt.assert_array_equal(idx, )
30-
#npt.assert_array_equal(dist2, np.asarray([[0.3, 0.1]]))
27+
npt.assert_array_almost_equal(idx, idx_answer)
28+
npt.assert_array_almost_equal(dist2, dist2_answer)
3129

3230

3331

0 commit comments

Comments
 (0)