@@ -146,33 +146,30 @@ def grouping_operation(features, idx):
146146 return grouped_features .reshape (idx .shape [0 ], features .shape [1 ], idx .shape [1 ], idx .shape [2 ])
147147
148148
149- class BallQueryDense (Function ):
150- @staticmethod
151- def forward (ctx , radius , nsample , xyz , new_xyz , batch_xyz = None , batch_new_xyz = None ):
152- # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
153- if new_xyz .is_cuda :
154- return tpcuda .ball_query_dense (new_xyz , xyz , radius , nsample )
155- else :
156- return tpcpu .dense_ball_query (new_xyz , xyz , radius , nsample , mode = 0 )
157-
158- @staticmethod
159- def backward (ctx , a = None ):
160- return None , None , None , None
161-
162-
163- class BallQueryPartialDense (Function ):
164- @staticmethod
165- def forward (ctx , radius , nsample , x , y , batch_x , batch_y ):
166- # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
167- if x .is_cuda :
168- return tpcuda .ball_query_partial_dense (x , y , batch_x , batch_y , radius , nsample )
169- else :
170- ind , dist = tpcpu .batch_ball_query (x , y , batch_x , batch_y , radius , nsample , mode = 0 )
171- return ind , dist
172-
173- @staticmethod
174- def backward (ctx , a = None ):
175- return None , None , None , None
149+ def ball_query_dense (radius , nsample , xyz , new_xyz , batch_xyz = None , batch_new_xyz = None , sort = False ):
150+ # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
151+ if new_xyz .is_cuda :
152+ if sort :
153+ raise NotImplementedError ("CUDA version does not sort the neighbors" )
154+ ind , dist = tpcuda .ball_query_dense (new_xyz , xyz , radius , nsample )
155+ else :
156+ ind , dist = tpcpu .dense_ball_query (new_xyz , xyz , radius , nsample , mode = 0 , sorted = sort )
157+ positive = dist > 0
158+ dist [positive ] = torch .sqrt (dist [positive ])
159+ return ind , dist
160+
161+
162+ def ball_query_partial_dense (radius , nsample , x , y , batch_x , batch_y , sort = False ):
163+ # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
164+ if x .is_cuda :
165+ if sort :
166+ raise NotImplementedError ("CUDA version does not sort the neighbors" )
167+ ind , dist = tpcuda .ball_query_partial_dense (x , y , batch_x , batch_y , radius , nsample )
168+ else :
169+ ind , dist = tpcpu .batch_ball_query (x , y , batch_x , batch_y , radius , nsample , mode = 0 , sorted = sort )
170+ positive = dist > 0
171+ dist [positive ] = torch .sqrt (dist [positive ])
172+ return ind , dist
176173
177174
178175def ball_query (
@@ -183,6 +180,7 @@ def ball_query(
183180 mode : Optional [str ] = "dense" ,
184181 batch_x : Optional [torch .tensor ] = None ,
185182 batch_y : Optional [torch .tensor ] = None ,
183+ sort : Optional [bool ] = False ,
186184) -> torch .Tensor :
187185 """
188186 Arguments:
@@ -197,11 +195,12 @@ def ball_query(
197195 Keyword Arguments:
198196 batch_x -- (M, ) [partial_dense] or (B, M, 3) [dense] Contains indexes to indicate within batch it belongs to.
199197 batch_y -- (N, ) Contains indexes to indicate within batch it belongs to
198+ sort -- bool wether the neighboors are sorted or not (closests first)
200199
201200 Returns:
202201 idx: (npoint, nsample) or (B, npoint, nsample) [dense] It contains the indexes of the element within x at radius distance to y
203- dist2 : (N, nsample) or (B, npoint, nsample) Default value: -1.
204- It contains the square distances of the element within x at radius distance to y
202+ dist : (N, nsample) or (B, npoint, nsample) Default value: -1.
203+ It contains the distance of the element within x at radius distance to y
205204 """
206205 if mode is None :
207206 raise Exception ('The mode should be defined within ["partial_dense | dense"]' )
@@ -212,12 +211,12 @@ def ball_query(
212211 assert x .size (0 ) == batch_x .size (0 )
213212 assert y .size (0 ) == batch_y .size (0 )
214213 assert x .dim () == 2
215- return BallQueryPartialDense . apply (radius , nsample , x , y , batch_x , batch_y )
214+ return ball_query_partial_dense (radius , nsample , x , y , batch_x , batch_y , sort = sort )
216215
217216 elif mode .lower () == "dense" :
218217 if (batch_x is not None ) or (batch_y is not None ):
219218 raise Exception ("batch_x and batch_y should not be provided" )
220219 assert x .dim () == 3
221- return BallQueryDense . apply (radius , nsample , x , y )
220+ return ball_query_dense (radius , nsample , x , y , sort = sort )
222221 else :
223222 raise Exception ("unrecognized mode {}" .format (mode ))
0 commit comments