99__global__ void query_ball_point_kernel_dense (int b, int n, int m, float radius, int nsample,
1010 const float * __restrict__ new_xyz,
1111 const float * __restrict__ xyz,
12- int64_t * __restrict__ idx_out,
12+ int64_t * __restrict__ idx_out,
1313 float * __restrict__ dist_out)
1414{
1515 int batch_index = blockIdx .x ;
@@ -51,15 +51,17 @@ __global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius,
5151 }
5252}
5353
54- __global__ void query_ball_point_kernel_partial_dense (
55- int size_x, int size_y, float radius, int nsample, const float * __restrict__ x,
56- const float * __restrict__ y, const int64_t * __restrict__ batch_x, const int64_t * __restrict__ batch_y,
57- int64_t * __restrict__ idx_out, float * __restrict__ dist_out)
54+ __global__ void query_ball_point_kernel_partial_dense (int size_x, int size_y, float radius,
55+ int nsample, const float * __restrict__ x,
56+ const float * __restrict__ y,
57+ const int64_t * __restrict__ batch_x,
58+ const int64_t * __restrict__ batch_y,
59+ int64_t * __restrict__ idx_out,
60+ float * __restrict__ dist_out)
5861{
5962 // taken from
6063 // https://github.com/rusty1s/pytorch_cluster/blob/master/cuda/radius_kernel.cu
6164 const ptrdiff_t batch_idx = blockIdx .x ;
62- const ptrdiff_t idx = threadIdx .x ;
6365
6466 const ptrdiff_t start_idx_x = batch_x[batch_idx];
6567 const ptrdiff_t end_idx_x = batch_x[batch_idx + 1 ];
@@ -68,10 +70,10 @@ __global__ void query_ball_point_kernel_partial_dense(
6870 const ptrdiff_t end_idx_y = batch_y[batch_idx + 1 ];
6971 float radius2 = radius * radius;
7072
71- for (ptrdiff_t n_x = start_idx_x + idx; n_x < end_idx_x; n_x += TOTAL_THREADS_SPARSE )
73+ for (ptrdiff_t n_y = start_idx_y + threadIdx . x ; n_y < end_idx_y; n_y += blockDim . x )
7274 {
7375 int64_t count = 0 ;
74- for (ptrdiff_t n_y = start_idx_y; n_y < end_idx_y; n_y ++)
76+ for (ptrdiff_t n_x = start_idx_x; n_x < end_idx_x; n_x ++)
7577 {
7678 float dist = 0 ;
7779 for (ptrdiff_t d = 0 ; d < 3 ; d++)
@@ -93,19 +95,21 @@ __global__ void query_ball_point_kernel_partial_dense(
9395}
9496
9597void query_ball_point_kernel_dense_wrapper (int b, int n, int m, float radius, int nsample,
96- const float * new_xyz, const float * xyz, int64_t * idx,float * dist_out)
98+ const float * new_xyz, const float * xyz, int64_t * idx,
99+ float * dist_out)
97100{
98101 cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
99102 query_ball_point_kernel_dense<<<b, opt_n_threads(m), 0 , stream>>> (b, n, m, radius, nsample,
100- new_xyz, xyz, idx,dist_out);
103+ new_xyz, xyz, idx, dist_out);
101104
102105 CUDA_CHECK_ERRORS ();
103106}
104107
105- void query_ball_point_kernel_partial_wrapper (int64_t batch_size, int size_x, int size_y, float radius,
106- int nsample, const float * x, const float * y,
107- const int64_t * batch_x, const int64_t * batch_y,
108- int64_t * idx_out, float * dist_out)
108+ void query_ball_point_kernel_partial_wrapper (int64_t batch_size, int size_x, int size_y,
109+ float radius, int nsample, const float * x,
110+ const float * y, const int64_t * batch_x,
111+ const int64_t * batch_y, int64_t * idx_out,
112+ float * dist_out)
109113{
110114 query_ball_point_kernel_partial_dense<<<batch_size, TOTAL_THREADS_SPARSE>>> (
111115 size_x, size_y, radius, nsample, x, y, batch_x, batch_y, idx_out, dist_out);
0 commit comments