@@ -59,87 +59,71 @@ std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor query,
5959 return std::make_pair (out.clone (), out_dists.clone ());
6060}
6161
62- void cumsum (const vector<long >& batch, vector<long >& res){
63-
64- res.resize (batch[batch.size ()-1 ]-batch[0 ]+2 , 0 );
65- long ind = batch[0 ];
66- long incr = 1 ;
67- if (res.size () > 1 ){
68- for (int i=1 ; i < batch.size (); i++){
69- if (batch[i] == ind)
70- incr++;
71- else {
72- res[ind-batch[0 ]+1 ] = incr;
73- incr =1 ;
74- ind = batch[i];
75- }
76- }
77-
78- }
79- res[ind-batch[0 ]+1 ] = incr;
62+ at::Tensor degree (at::Tensor row, int64_t num_nodes) {
63+ auto zero = at::zeros (num_nodes, row.options ());
64+ auto one = at::ones (row.size (0 ), row.options ());
65+ return zero.scatter_add_ (0 , row, one);
8066}
8167
8268std::pair<at::Tensor, at::Tensor> batch_ball_query (at::Tensor query,
8369 at::Tensor support,
8470 at::Tensor query_batch,
8571 at::Tensor support_batch,
8672 float radius, int max_num, int mode) {
87- at::Tensor out;
88- at::Tensor out_dists;
89- auto data_qb = query_batch.DATA_PTR <long >();
90- auto data_sb = support_batch.DATA_PTR <long >();
91- std::vector<long > query_batch_stl = std::vector<long >(data_qb, data_qb+query_batch.size (0 ));
92- std::vector<long > cumsum_query_batch_stl;
93- cumsum (query_batch_stl, cumsum_query_batch_stl);
94-
95- std::vector<long > support_batch_stl = std::vector<long >(data_sb, data_sb+support_batch.size (0 ));
96- std::vector<long > cumsum_support_batch_stl;
97- cumsum (support_batch_stl, cumsum_support_batch_stl);
73+ at::Tensor idx;
9874
75+ at::Tensor dist;
9976 std::vector<long > neighbors_indices;
77+ std::vector<float > neighbors_dists;
10078
10179 auto options = torch::TensorOptions ().dtype (torch::kLong ).device (torch::kCPU );
10280 auto options_dist = torch::TensorOptions ().dtype (torch::kFloat32 ).device (torch::kCPU );
81+
10382 int max_count = 0 ;
104- std::vector<float > neighbors_dists;
83+ auto batch_access = query_batch.accessor <int64_t , 1 >();
84+ auto batch_size = batch_access[-1 ] + 1 ;
85+ query_batch = degree (query_batch, batch_size);
86+ query_batch = at::cat ({at::zeros (1 , query_batch.options ()), query_batch.cumsum (0 )}, 0 );
87+ support_batch = degree (support_batch, batch_size);
88+ support_batch = at::cat ({at::zeros (1 , support_batch.options ()), support_batch.cumsum (0 )}, 0 );
89+ std::vector<long > query_batch_stl (query_batch.DATA_PTR <long >(), query_batch.DATA_PTR <long >() + query_batch.numel ());
90+ std::vector<long > support_batch_stl (support_batch.DATA_PTR <long >(), support_batch.DATA_PTR <long >() + support_batch.numel ());
91+
10592 AT_DISPATCH_ALL_TYPES (query.scalar_type (), " batch_radius_search" , [&] {
10693
107- auto data_q = query.DATA_PTR <scalar_t >();
108- auto data_s = support.DATA_PTR <scalar_t >();
109- std::vector<scalar_t > queries_stl = std::vector<scalar_t >(data_q,
110- data_q + query.size (0 )*query.size (1 ));
111- std::vector<scalar_t > supports_stl = std::vector<scalar_t >(data_s,
112- data_s + support.size (0 )*support.size (1 ));
94+ std::vector<scalar_t > queries_stl (query.DATA_PTR <scalar_t >(), query.DATA_PTR <scalar_t >() + query.numel ());
95+ std::vector<scalar_t > supports_stl (support.DATA_PTR <scalar_t >(), support.DATA_PTR <scalar_t >() + support.numel ());
11396
11497
115- max_count = batch_nanoflann_neighbors<scalar_t >(queries_stl,
98+ max_count = batch_nanoflann_neighbors<scalar_t >(queries_stl,
11699 supports_stl,
117- cumsum_query_batch_stl ,
118- cumsum_support_batch_stl ,
100+ query_batch_stl ,
101+ support_batch_stl ,
119102 neighbors_indices,
120103 neighbors_dists,
121104 radius,
122105 max_num,
123106 mode);
124- });
125-
126- long * neighbors_indices_ptr = neighbors_indices.data ();
127- auto neighbors_dists_ptr = neighbors_dists.data ();
128107
129108
109+ });
110+ auto neighbors_dists_ptr = neighbors_dists.data ();
111+ long * neighbors_indices_ptr = neighbors_indices.data ();
130112 if (mode == 0 ){
131- out = torch::from_blob (neighbors_indices_ptr, {query.size (0 ), max_count}, options=options);
132- out_dists = torch::from_blob (neighbors_dists_ptr,
133- {query.size (0 ), max_count},
134- options=options_dist);
113+ idx = torch::from_blob (neighbors_indices_ptr, {query.size (0 ), max_count}, options=options);
114+ dist = torch::from_blob (neighbors_dists_ptr,
115+ {query.size (0 ), max_count},
116+ options=options_dist);
117+
135118 }
136- else if (mode == 1 ){
137- out = torch::from_blob (neighbors_indices_ptr, {(int )neighbors_indices.size ()/2 , 2 }, options=options);
138- out_dists = torch::from_blob (neighbors_dists_ptr,
139- {(int )neighbors_indices.size ()/2 , 1 },
140- options=options_dist);
119+ else if (mode ==1 ){
120+ idx = torch::from_blob (neighbors_indices_ptr, {(int )neighbors_indices.size ()/2 , 2 }, options=options);
121+ dist = torch::from_blob (neighbors_dists_ptr,
122+ {(int )neighbors_indices.size ()/2 , 1 },
123+ options=options_dist);
141124 }
142- return std::make_pair (out.clone (), out_dists.clone ());
125+ return std::make_pair (idx.clone (), dist.clone ());
126+
143127}
144128
145129
0 commit comments