33#include " utils.h"
44
55void query_ball_point_kernel_dense_wrapper (int b, int n, int m, float radius, int nsample,
6- const float * new_xyz, const float * xyz, int * idx,
6+ const float * new_xyz, const float * xyz, long * idx,
77 float * dist_out);
88
99void query_ball_point_kernel_partial_wrapper (long batch_size, int size_x, int size_y, float radius,
@@ -25,15 +25,15 @@ std::pair<at::Tensor, at::Tensor> ball_query_dense(at::Tensor new_xyz, at::Tenso
2525 }
2626
2727 at::Tensor idx = torch::zeros ({new_xyz.size (0 ), new_xyz.size (1 ), nsample},
28- at::device (new_xyz.device ()).dtype (at::ScalarType::Int ));
28+ at::device (new_xyz.device ()).dtype (at::ScalarType::Long ));
2929 at::Tensor dist = torch::full ({new_xyz.size (0 ), new_xyz.size (1 ), nsample}, -1 ,
3030 at::device (new_xyz.device ()).dtype (at::ScalarType::Float));
3131
3232 if (new_xyz.type ().is_cuda ())
3333 {
3434 query_ball_point_kernel_dense_wrapper (
3535 xyz.size (0 ), xyz.size (1 ), new_xyz.size (1 ), radius, nsample, new_xyz.DATA_PTR <float >(),
36- xyz.DATA_PTR <float >(), idx.DATA_PTR <int >(), dist.DATA_PTR <float >());
36+ xyz.DATA_PTR <float >(), idx.DATA_PTR <long >(), dist.DATA_PTR <float >());
3737 }
3838 else
3939 {
0 commit comments