Skip to content

Commit de2e845

Browse files
committed
Added Windows support
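The width of `long` depends on the platform's data model (32 bits on 64-bit Windows, 64 bits on 64-bit Linux and macOS), whereas `int64_t` is exactly 64 bits on every platform, so the index and batch buffers use a consistent 64-bit type everywhere.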
1 parent 0220fa5 commit de2e845

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

cuda/src/ball_query.cpp

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -3,13 +3,13 @@
33
#include "utils.h"
44

55
void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, int nsample,
6-
const float* new_xyz, const float* xyz, long* idx,
6+
const float* new_xyz, const float* xyz, int64_t* idx,
77
float* dist_out);
88

9-
void query_ball_point_kernel_partial_wrapper(long batch_size, int size_x, int size_y, float radius,
9+
void query_ball_point_kernel_partial_wrapper(int64_t batch_size, int size_x, int size_y, float radius,
1010
int nsample, const float* x, const float* y,
11-
const long* batch_x, const long* batch_y,
12-
long* idx_out, float* dist_out);
11+
const int64_t* batch_x, const int64_t* batch_y,
12+
int64_t* idx_out, float* dist_out);
1313

1414
std::pair<at::Tensor, at::Tensor> ball_query_dense(at::Tensor new_xyz, at::Tensor xyz,
1515
const float radius, const int nsample)
@@ -29,7 +29,7 @@ std::pair<at::Tensor, at::Tensor> ball_query_dense(at::Tensor new_xyz, at::Tenso
2929

3030
query_ball_point_kernel_dense_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), radius,
3131
nsample, new_xyz.DATA_PTR<float>(), xyz.DATA_PTR<float>(),
32-
idx.DATA_PTR<long>(), dist.DATA_PTR<float>());
32+
idx.DATA_PTR<int64_t>(), dist.DATA_PTR<float>());
3333

3434
return std::make_pair(idx, dist);
3535
}
@@ -73,8 +73,8 @@ std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x, at::Ten
7373

7474
query_ball_point_kernel_partial_wrapper(batch_size, x.size(0), y.size(0), radius, nsample,
7575
x.DATA_PTR<float>(), y.DATA_PTR<float>(),
76-
batch_x.DATA_PTR<long>(), batch_y.DATA_PTR<long>(),
77-
idx.DATA_PTR<long>(), dist.DATA_PTR<float>());
76+
batch_x.DATA_PTR<int64_t>(), batch_y.DATA_PTR<int64_t>(),
77+
idx.DATA_PTR<int64_t>(), dist.DATA_PTR<float>());
7878

7979
return std::make_pair(idx, dist);
8080
}

cuda/src/ball_query_gpu.cu

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
__global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius, int nsample,
1010
const float* __restrict__ new_xyz,
1111
const float* __restrict__ xyz,
12-
long* __restrict__ idx_out,
12+
int64_t* __restrict__ idx_out,
1313
float* __restrict__ dist_out)
1414
{
1515
int batch_index = blockIdx.x;
@@ -53,7 +53,7 @@ __global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius,
5353

5454
__global__ void query_ball_point_kernel_partial_dense(
5555
int size_x, int size_y, float radius, int nsample, const float* __restrict__ x,
56-
const float* __restrict__ y, const long* __restrict__ batch_x, const long* __restrict__ batch_y,
56+
const float* __restrict__ y, const int64_t* __restrict__ batch_x, const int64_t* __restrict__ batch_y,
5757
int64_t* __restrict__ idx_out, float* __restrict__ dist_out)
5858
{
5959
// taken from
@@ -93,7 +93,7 @@ __global__ void query_ball_point_kernel_partial_dense(
9393
}
9494

9595
void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, int nsample,
96-
const float* new_xyz, const float* xyz, long* idx,float* dist_out)
96+
const float* new_xyz, const float* xyz, int64_t* idx,float* dist_out)
9797
{
9898
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
9999
query_ball_point_kernel_dense<<<b, opt_n_threads(m), 0, stream>>>(b, n, m, radius, nsample,
@@ -102,9 +102,9 @@ void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, in
102102
CUDA_CHECK_ERRORS();
103103
}
104104

105-
void query_ball_point_kernel_partial_wrapper(long batch_size, int size_x, int size_y, float radius,
105+
void query_ball_point_kernel_partial_wrapper(int64_t batch_size, int size_x, int size_y, float radius,
106106
int nsample, const float* x, const float* y,
107-
const long* batch_x, const long* batch_y,
107+
const int64_t* batch_x, const int64_t* batch_y,
108108
int64_t* idx_out, float* dist_out)
109109
{
110110
query_ball_point_kernel_partial_dense<<<batch_size, TOTAL_THREADS_SPARSE>>>(

0 commit comments

Comments
 (0)