Skip to content

Commit 66aac86

Browse files
Support thread number for both modes
1 parent c38528e commit 66aac86

File tree

2 files changed: 6 additions and 5 deletions.

cuda/include/cuda_utils.h

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,18 +10,19 @@
1010

1111
#include <vector>
1212

13-
#define TOTAL_THREADS 512
13+
#define TOTAL_THREADS_DENSE 512
14+
#define TOTAL_THREADS_SPARSE 1024
1415

1516
inline int opt_n_threads(int work_size) {
1617
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
1718

18-
return max(min(1 << pow_2, TOTAL_THREADS), 1);
19+
return max(min(1 << pow_2, TOTAL_THREADS_DENSE), 1);
1920
}
2021

2122
inline dim3 opt_block_config(int x, int y) {
2223
const int x_threads = opt_n_threads(x);
2324
const int y_threads =
24-
max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
25+
max(min(opt_n_threads(y), TOTAL_THREADS_DENSE / x_threads), 1);
2526
dim3 block_config(x_threads, y_threads, 1);
2627

2728
return block_config;

cuda/src/ball_query_gpu.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ __global__ void query_ball_point_kernel_partial_dense(int size_x,
6666
const ptrdiff_t end_idx_y = batch_y[batch_idx + 1];
6767
float radius2 = radius * radius;
6868

69-
for (ptrdiff_t n_x = start_idx_x + idx; n_x < end_idx_x; n_x += TOTAL_THREADS) {
69+
for (ptrdiff_t n_x = start_idx_x + idx; n_x < end_idx_x; n_x += TOTAL_THREADS_SPARSE) {
7070
int64_t count = 0;
7171
for (ptrdiff_t n_y = start_idx_y; n_y < end_idx_y; n_y++) {
7272
float dist = 0;
@@ -108,7 +108,7 @@ void query_ball_point_kernel_partial_wrapper(long batch_size,
108108
int64_t *idx_out,
109109
float *dist_out) {
110110

111-
query_ball_point_kernel_partial_dense<<<batch_size, TOTAL_THREADS>>>(
111+
query_ball_point_kernel_partial_dense<<<batch_size, TOTAL_THREADS_SPARSE>>>(
112112
size_x, size_y, radius, nsample, x, y,
113113
batch_x, batch_y, idx_out, dist_out);
114114

0 commit comments

Comments
 (0)