Skip to content

Commit edc7c41

Browse files
index as long
1 parent 6dafff9 commit edc7c41

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

cuda/src/ball_query.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "utils.h"
44

55
void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, int nsample,
6-
const float* new_xyz, const float* xyz, int* idx,
6+
const float* new_xyz, const float* xyz, long* idx,
77
float* dist_out);
88

99
void query_ball_point_kernel_partial_wrapper(long batch_size, int size_x, int size_y, float radius,
@@ -25,15 +25,15 @@ std::pair<at::Tensor, at::Tensor> ball_query_dense(at::Tensor new_xyz, at::Tenso
2525
}
2626

2727
at::Tensor idx = torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
28-
at::device(new_xyz.device()).dtype(at::ScalarType::Int));
28+
at::device(new_xyz.device()).dtype(at::ScalarType::Long));
2929
at::Tensor dist = torch::full({new_xyz.size(0), new_xyz.size(1), nsample}, -1,
3030
at::device(new_xyz.device()).dtype(at::ScalarType::Float));
3131

3232
if (new_xyz.type().is_cuda())
3333
{
3434
query_ball_point_kernel_dense_wrapper(
3535
xyz.size(0), xyz.size(1), new_xyz.size(1), radius, nsample, new_xyz.DATA_PTR<float>(),
36-
xyz.DATA_PTR<float>(), idx.DATA_PTR<int>(), dist.DATA_PTR<float>());
36+
xyz.DATA_PTR<float>(), idx.DATA_PTR<long>(), dist.DATA_PTR<float>());
3737
}
3838
else
3939
{

cuda/src/ball_query_gpu.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
__global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius, int nsample,
1010
const float* __restrict__ new_xyz,
1111
const float* __restrict__ xyz,
12-
int* __restrict__ idx_out,
12+
long* __restrict__ idx_out,
1313
float* __restrict__ dist_out)
1414
{
1515
int batch_index = blockIdx.x;
@@ -93,7 +93,7 @@ __global__ void query_ball_point_kernel_partial_dense(
9393
}
9494

9595
void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, int nsample,
96-
const float* new_xyz, const float* xyz, int* idx,float* dist_out)
96+
const float* new_xyz, const float* xyz, long* idx,float* dist_out)
9797
{
9898
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
9999
query_ball_point_kernel_dense<<<b, opt_n_threads(m), 0, stream>>>(b, n, m, radius, nsample,

0 commit comments

Comments
 (0)