1111 import torch_points .points_cuda as tpcuda
1212
1313
14- class FurthestPointSampling (Function ):
15- @staticmethod
16- def forward (ctx , xyz , npoint ):
17- if xyz .is_cuda :
18- return tpcuda .furthest_point_sampling (xyz , npoint )
19- else :
20- return tpcpu .fps (xyz , npoint , True )
21-
22- @staticmethod
23- def backward (xyz , a = None ):
24- return None , None
25-
26-
2714def furthest_point_sample (xyz , npoint ):
2815 # type: (Any, torch.Tensor, int) -> torch.Tensor
2916 r"""
@@ -42,24 +29,10 @@ def furthest_point_sample(xyz, npoint):
4229 torch.Tensor
4330 (B, npoint) tensor containing the set
4431 """
45- return FurthestPointSampling .apply (xyz , npoint )
46-
47-
48- class ThreeNN (Function ):
49- @staticmethod
50- def forward (ctx , unknown , known ):
51- # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
52-
53- if unknown .is_cuda :
54- dist2 , idx = tpcuda .three_nn (unknown , known )
55- else :
56- idx , dist2 = knn (known , unknown , 3 )
57-
58- return torch .sqrt (dist2 ), idx
59-
60- @staticmethod
61- def backward (ctx , a = None , b = None ):
62- return None , None
32+ if xyz .is_cuda :
33+ return tpcuda .furthest_point_sampling (xyz , npoint )
34+ else :
35+ return tpcpu .fps (xyz , npoint , True )
6336
6437
6538def three_nn (unknown , known ):
@@ -79,7 +52,12 @@ def three_nn(unknown, known):
7952 idx : torch.Tensor
8053 (B, n, 3) index of 3 nearest neighbors
8154 """
82- return ThreeNN .apply (unknown , known )
55+ if unknown .is_cuda :
56+ dist2 , idx = tpcuda .three_nn (unknown , known )
57+ else :
58+ idx , dist2 = knn (known , unknown , 3 )
59+
60+ return torch .sqrt (dist2 ), idx
8361
8462
8563class ThreeInterpolate (Function ):
0 commit comments