66
77// input: new_xyz(b, m, 3) xyz(b, n, 3)
88// output: idx(b, m, nsample)
9- __global__ void query_ball_point_kernel (int b, int n, int m, float radius,
10- int nsample,
11- const float *__restrict__ new_xyz,
12- const float *__restrict__ xyz,
13- int *__restrict__ idx) {
9+ __global__ void query_ball_point_kernel_dense (int b, int n, int m, float radius,
10+ int nsample,
11+ const float *__restrict__ new_xyz,
12+ const float *__restrict__ xyz,
13+ int *__restrict__ idx_out) {
14+
1415 int batch_index = blockIdx .x ;
1516 xyz += batch_index * n * 3 ;
1617 new_xyz += batch_index * m * 3 ;
17- idx += m * nsample * batch_index;
18+ idx_out += m * nsample * batch_index;
1819
1920 int index = threadIdx .x ;
2021 int stride = blockDim .x ;
@@ -33,22 +34,83 @@ __global__ void query_ball_point_kernel(int b, int n, int m, float radius,
3334 if (d2 < radius2) {
3435 if (cnt == 0 ) {
3536 for (int l = 0 ; l < nsample; ++l) {
36- idx [j * nsample + l] = k;
37+ idx_out [j * nsample + l] = k;
3738 }
3839 }
39- idx [j * nsample + cnt] = k;
40+ idx_out [j * nsample + cnt] = k;
4041 ++cnt;
4142 }
4243 }
4344 }
4445}
4546
46- void query_ball_point_kernel_wrapper (int b, int n, int m, float radius,
47- int nsample, const float *new_xyz,
48- const float *xyz, int *idx) {
47+ __global__ void query_ball_point_kernel_partial_dense (int size_x,
48+ int size_y,
49+ float radius,
50+ int nsample,
51+ const float *__restrict__ x,
52+ const float *__restrict__ y,
53+ const long *__restrict__ batch_x,
54+ const long *__restrict__ batch_y,
55+ int64_t *__restrict__ idx_out,
56+ float * __restrict__ dist_out) {
57+
58+ // taken from https://github.com/rusty1s/pytorch_cluster/blob/master/cuda/radius_kernel.cu
59+ const ptrdiff_t batch_idx = blockIdx .x ;
60+ const ptrdiff_t idx = threadIdx .x ;
61+
62+ const ptrdiff_t start_idx_x = batch_x[batch_idx];
63+ const ptrdiff_t end_idx_x = batch_x[batch_idx + 1 ];
64+
65+ const ptrdiff_t start_idx_y = batch_y[batch_idx];
66+ const ptrdiff_t end_idx_y = batch_y[batch_idx + 1 ];
67+ float radius2 = radius * radius;
68+
69+ for (ptrdiff_t n_x = start_idx_x + idx; n_x < end_idx_x; n_x += TOTAL_THREADS) {
70+ int64_t count = 0 ;
71+ for (ptrdiff_t n_y = start_idx_y; n_y < end_idx_y; n_y++) {
72+ float dist = 0 ;
73+ for (ptrdiff_t d = 0 ; d < 3 ; d++) {
74+ dist += (x[n_x * 3 + d] - y[n_y * 3 + d]) *
75+ (x[n_x * 3 + d] - y[n_y * 3 + d]);
76+ }
77+ if (dist <= radius2){
78+ idx_out[n_x * nsample + count] = n_y;
79+ dist_out[n_x * nsample + count] = dist;
80+ count++;
81+ }
82+ if (count >= nsample){
83+ break ;
84+ }
85+ }
86+ }
87+ }
88+
89+ void query_ball_point_kernel_dense_wrapper (int b, int n, int m, float radius,
90+ int nsample, const float *new_xyz,
91+ const float *xyz, int *idx) {
4992 cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
50- query_ball_point_kernel <<<b, opt_n_threads(m), 0 , stream>>> (
93+ query_ball_point_kernel_dense <<<b, opt_n_threads(m), 0 , stream>>> (
5194 b, n, m, radius, nsample, new_xyz, xyz, idx);
5295
5396 CUDA_CHECK_ERRORS ();
5497}
98+
99+ void query_ball_point_kernel_partial_wrapper (long batch_size,
100+ int size_x,
101+ int size_y,
102+ float radius,
103+ int nsample,
104+ const float *x,
105+ const float *y,
106+ const long *batch_x,
107+ const long *batch_y,
108+ int64_t *idx_out,
109+ float *dist_out) {
110+
111+ query_ball_point_kernel_partial_dense<<<batch_size, TOTAL_THREADS>>> (
112+ size_x, size_y, radius, nsample, x, y,
113+ batch_x, batch_y, idx_out, dist_out);
114+
115+ CUDA_CHECK_ERRORS ();
116+ }
0 commit comments