99__global__ void query_ball_point_kernel_dense (int b, int n, int m, float radius, int nsample,
1010 const float * __restrict__ new_xyz,
1111 const float * __restrict__ xyz,
12- int64_t * __restrict__ idx_out,
12+ int64_t * __restrict__ idx_out,
1313 float * __restrict__ dist_out)
1414{
1515 int batch_index = blockIdx .x ;
@@ -51,10 +51,13 @@ __global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius,
5151 }
5252}
5353
54- __global__ void query_ball_point_kernel_partial_dense (
55- int size_x, int size_y, float radius, int nsample, const float * __restrict__ x,
56- const float * __restrict__ y, const int64_t * __restrict__ batch_x, const int64_t * __restrict__ batch_y,
57- int64_t * __restrict__ idx_out, float * __restrict__ dist_out)
54+ __global__ void query_ball_point_kernel_partial_dense (int size_x, int size_y, float radius,
55+ int nsample, const float * __restrict__ x,
56+ const float * __restrict__ y,
57+ const int64_t * __restrict__ batch_x,
58+ const int64_t * __restrict__ batch_y,
59+ int64_t * __restrict__ idx_out,
60+ float * __restrict__ dist_out)
5861{
5962 // taken from
6063 // https://github.com/rusty1s/pytorch_cluster/blob/master/cuda/radius_kernel.cu
@@ -67,7 +70,7 @@ __global__ void query_ball_point_kernel_partial_dense(
6770 const ptrdiff_t end_idx_y = batch_y[batch_idx + 1 ];
6871 float radius2 = radius * radius;
6972
70- for (ptrdiff_t n_y = start_idx_y + threadIdx .x ; n_y < end_idx_y; n_y += blockDim .x )
73+ for (ptrdiff_t n_y = start_idx_y + threadIdx .x ; n_y < end_idx_y; n_y += blockDim .x )
7174 {
7275 int64_t count = 0 ;
7376 for (ptrdiff_t n_x = start_idx_x; n_x < end_idx_x; n_x++)
@@ -92,19 +95,21 @@ __global__ void query_ball_point_kernel_partial_dense(
9295}
9396
9497void query_ball_point_kernel_dense_wrapper (int b, int n, int m, float radius, int nsample,
95- const float * new_xyz, const float * xyz, int64_t * idx,float * dist_out)
98+ const float * new_xyz, const float * xyz, int64_t * idx,
99+ float * dist_out)
96100{
97101 cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
98102 query_ball_point_kernel_dense<<<b, opt_n_threads(m), 0 , stream>>> (b, n, m, radius, nsample,
99- new_xyz, xyz, idx,dist_out);
103+ new_xyz, xyz, idx, dist_out);
100104
101105 CUDA_CHECK_ERRORS ();
102106}
103107
104- void query_ball_point_kernel_partial_wrapper (int64_t batch_size, int size_x, int size_y, float radius,
105- int nsample, const float * x, const float * y,
106- const int64_t * batch_x, const int64_t * batch_y,
107- int64_t * idx_out, float * dist_out)
108+ void query_ball_point_kernel_partial_wrapper (int64_t batch_size, int size_x, int size_y,
109+ float radius, int nsample, const float * x,
110+ const float * y, const int64_t * batch_x,
111+ const int64_t * batch_y, int64_t * idx_out,
112+ float * dist_out)
108113{
109114 query_ball_point_kernel_partial_dense<<<batch_size, TOTAL_THREADS_SPARSE>>> (
110115 size_x, size_y, radius, nsample, x, y, batch_x, batch_y, idx_out, dist_out);
0 commit comments