99 import torch_points .points_cuda as tpcuda
1010
1111
12-
1312class FurthestPointSampling (Function ):
1413 @staticmethod
1514 def forward (ctx , xyz , npoint ):
@@ -289,15 +288,16 @@ def ball_query_dense(radius, nsample, xyz, new_xyz):
289288 """
290289 return BallQueryDense .apply (radius , nsample , xyz , new_xyz )
291290
291+
292292class BallQueryPartialDense (Function ):
293293 @staticmethod
294294 def forward (ctx , radius , nsample , x , y , batch_x , batch_y ):
295295 # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
296296 if x .is_cuda :
297- return tpcuda .ball_query_partial_dense (x , y ,
298- batch_x ,
299- batch_y ,
300- radius , nsample )
297+ return tpcuda .ball_query_partial_dense (x , y ,
298+ batch_x ,
299+ batch_y ,
300+ radius , nsample )
301301 else :
302302 raise NotImplementedError
303303
@@ -315,7 +315,7 @@ def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y):
315315 nsample : int
316316 maximum number of features in the balls
317317 x : torch.Tensor
318- (M, 3) xyz coordinates of the features
318+ (M, 3) xyz coordinates of the features (The neighbours are going to be looked for there)
319319 y : torch.Tensor
320320 (N, npoint, 3) centers of the ball query
321321 batch_x : torch.Tensor
@@ -326,11 +326,12 @@ def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y):
326326 Returns
327327 -------
328328 torch.Tensor
329- idx: (M , nsample) Default value: N. It contains the indexes of the element within y at radius distance to x
330- dist2: (M , nsample) Default value: -1. It contains the square distances of the element within y at radius distance to x
329+ idx: (N , nsample) Default value: N. It contains the indexes of the element within y at radius distance to x
330+ dist2: (N , nsample) Default value: -1. It contains the square distances of the element within y at radius distance to x
331331 """
332332 return BallQueryPartialDense .apply (radius , nsample , x , y , batch_x , batch_y )
333333
334+
334335def ball_query (radius : float , nsample : int , x , y , batch_x = None , batch_y = None , mode = None ):
335336 if mode is None :
336337 raise Exception ('The mode should be defined within ["PARTIAL_DENSE | DENSE"]' )
@@ -347,4 +348,4 @@ def ball_query(radius: float, nsample: int, x, y, batch_x=None, batch_y=None, mo
347348 raise Exception ('batch_x and batch_y should not be provided' )
348349 return ball_query_dense (radius , nsample , x , y )
349350 else :
350- raise Exception ('unrecognized mode {}' .format (mode ))
351+ raise Exception ('unrecognized mode {}' .format (mode ))
0 commit comments