Skip to content

Commit a210ff9

Browse files
Removing useless autograd
1 parent 6739a3f commit a210ff9

File tree

1 file changed

+10
-32
lines changed

1 file changed

+10
-32
lines changed

torch_points/torchpoints.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,6 @@
1111
import torch_points.points_cuda as tpcuda
1212

1313

14-
class FurthestPointSampling(Function):
15-
@staticmethod
16-
def forward(ctx, xyz, npoint):
17-
if xyz.is_cuda:
18-
return tpcuda.furthest_point_sampling(xyz, npoint)
19-
else:
20-
return tpcpu.fps(xyz, npoint, True)
21-
22-
@staticmethod
23-
def backward(xyz, a=None):
24-
return None, None
25-
26-
2714
def furthest_point_sample(xyz, npoint):
2815
# type: (Any, torch.Tensor, int) -> torch.Tensor
2916
r"""
@@ -42,24 +29,10 @@ def furthest_point_sample(xyz, npoint):
4229
torch.Tensor
4330
(B, npoint) tensor containing the set
4431
"""
45-
return FurthestPointSampling.apply(xyz, npoint)
46-
47-
48-
class ThreeNN(Function):
49-
@staticmethod
50-
def forward(ctx, unknown, known):
51-
# type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
52-
53-
if unknown.is_cuda:
54-
dist2, idx = tpcuda.three_nn(unknown, known)
55-
else:
56-
idx, dist2 = knn(known, unknown, 3)
57-
58-
return torch.sqrt(dist2), idx
59-
60-
@staticmethod
61-
def backward(ctx, a=None, b=None):
62-
return None, None
32+
if xyz.is_cuda:
33+
return tpcuda.furthest_point_sampling(xyz, npoint)
34+
else:
35+
return tpcpu.fps(xyz, npoint, True)
6336

6437

6538
def three_nn(unknown, known):
@@ -79,7 +52,12 @@ def three_nn(unknown, known):
7952
idx : torch.Tensor
8053
(B, n, 3) index of 3 nearest neighbors
8154
"""
82-
return ThreeNN.apply(unknown, known)
55+
if unknown.is_cuda:
56+
dist2, idx = tpcuda.three_nn(unknown, known)
57+
else:
58+
idx, dist2 = knn(known, unknown, 3)
59+
60+
return torch.sqrt(dist2), idx
8361

8462

8563
class ThreeInterpolate(Function):

0 commit comments

Comments
 (0)