@@ -15,7 +15,7 @@ std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor support, at::Tensor quer
1515
1616 at::Tensor out;
1717 at::Tensor out_dists;
18- std::vector<long > neighbors_indices (query.size (0 ), 0 );
18+ std::vector<int64_t > neighbors_indices (query.size (0 ), 0 );
1919 std::vector<float > neighbors_dists (query.size (0 ), -1 );
2020
2121 auto options = torch::TensorOptions ().dtype (torch::kLong ).device (torch::kCPU );
@@ -34,7 +34,7 @@ std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor support, at::Tensor quer
3434 neighbors_dists, radius, max_num, mode, sorted);
3535 });
3636 auto neighbors_dists_ptr = neighbors_dists.data ();
37- long * neighbors_indices_ptr = neighbors_indices.data ();
37+ int64_t * neighbors_indices_ptr = neighbors_indices.data ();
3838 if (mode == 0 )
3939 {
4040 out =
@@ -73,7 +73,7 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
7373 at::Tensor idx;
7474
7575 at::Tensor dist;
76- std::vector<long > neighbors_indices;
76+ std::vector<int64_t > neighbors_indices;
7777 std::vector<float > neighbors_dists;
7878
7979 auto options = torch::TensorOptions ().dtype (torch::kLong ).device (torch::kCPU );
@@ -91,10 +91,11 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
9191 query_batch = at::cat ({at::zeros (1 , query_batch.options ()), query_batch.cumsum (0 )}, 0 );
9292 support_batch = degree (support_batch, batch_size);
9393 support_batch = at::cat ({at::zeros (1 , support_batch.options ()), support_batch.cumsum (0 )}, 0 );
94- std::vector<long > query_batch_stl (query_batch.DATA_PTR <long >(),
95- query_batch.DATA_PTR <long >() + query_batch.numel ());
96- std::vector<long > support_batch_stl (support_batch.DATA_PTR <long >(),
97- support_batch.DATA_PTR <long >() + support_batch.numel ());
94+ std::vector<int64_t > query_batch_stl (query_batch.DATA_PTR <int64_t >(),
95+ query_batch.DATA_PTR <int64_t >() + query_batch.numel ());
96+ std::vector<int64_t > support_batch_stl (support_batch.DATA_PTR <int64_t >(),
97+ support_batch.DATA_PTR <int64_t >() +
98+ support_batch.numel ());
9899
99100 AT_DISPATCH_ALL_TYPES (query.scalar_type (), " batch_radius_search" , [&] {
100101 std::vector<scalar_t > queries_stl (query.DATA_PTR <scalar_t >(),
@@ -107,7 +108,7 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
107108 neighbors_dists, radius, max_num, mode, sorted);
108109 });
109110 auto neighbors_dists_ptr = neighbors_dists.data ();
110- long * neighbors_indices_ptr = neighbors_indices.data ();
111+ int64_t * neighbors_indices_ptr = neighbors_indices.data ();
111112
112113 if (mode == 0 )
113114 {
0 commit comments