33import torch .nn as nn
44import sys
55
6- import torch_points .points_cuda as tpcuda
6+ import torch_points .points_cpu as tpcpu
7+
8+ if torch .cuda .is_available ():
9+ import torch_points .points_cuda as tpcuda
710
811
912class FurthestPointSampling (Function ):
1013 @staticmethod
1114 def forward (ctx , xyz , npoint ):
12- return tpcuda .furthest_point_sampling (xyz , npoint )
15+ if xyz .is_cuda :
16+ return tpcuda .furthest_point_sampling (xyz , npoint )
17+ else :
18+ raise NotImplementedError
1319
1420 @staticmethod
1521 def backward (xyz , a = None ):
@@ -45,14 +51,20 @@ def forward(ctx, features, idx):
4551
4652 ctx .for_backwards = (idx , C , N )
4753
48- return tpcuda .gather_points (features , idx )
54+ if features .is_cuda :
55+ return tpcuda .gather_points (features , idx )
56+ else :
57+ return tpcpu .gather_points (features , idx )
4958
5059 @staticmethod
5160 def backward (ctx , grad_out ):
5261 idx , C , N = ctx .for_backwards
5362
54- grad_features = tpcuda .gather_points_grad (grad_out .contiguous (), idx , N )
55- return grad_features , None
63+ if grad_out .is_cuda :
64+ grad_features = tpcuda .gather_points_grad (grad_out .contiguous (), idx , N )
65+ return grad_features , None
66+ else :
67+ raise NotImplementedError
5668
5769
5870def gather_operation (features , idx ):
@@ -64,12 +76,12 @@ def gather_operation(features, idx):
6476 (B, C, N) tensor
6577
6678 idx : torch.Tensor
67- (B, npoint) tensor of the features to gather
79+ (B, npoint, nsample ) tensor of the features to gather
6880
6981 Returns
7082 -------
7183 torch.Tensor
72- (B, C, npoint) tensor
84+ (B, C, npoint, nsample ) tensor
7385 """
7486 return GatherOperation .apply (features , idx )
7587
@@ -78,7 +90,11 @@ class ThreeNN(Function):
7890 @staticmethod
7991 def forward (ctx , unknown , known ):
8092 # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
81- dist2 , idx = tpcuda .three_nn (unknown , known )
93+
94+ if unknown .is_cuda :
95+ dist2 , idx = tpcuda .three_nn (unknown , known )
96+ else :
97+ raise NotImplementedError
8298
8399 return torch .sqrt (dist2 ), idx
84100
@@ -116,7 +132,10 @@ def forward(ctx, features, idx, weight):
116132
117133 ctx .three_interpolate_for_backward = (idx , weight , m )
118134
119- return tpcuda .three_interpolate (features , idx , weight )
135+ if features .is_cuda :
136+ return tpcuda .three_interpolate (features , idx , weight )
137+ else :
138+ raise NotImplementedError
120139
121140 @staticmethod
122141 def backward (ctx , grad_out ):
@@ -138,9 +157,12 @@ def backward(ctx, grad_out):
138157 """
139158 idx , weight , m = ctx .three_interpolate_for_backward
140159
141- grad_features = tpcuda .three_interpolate_grad (
142- grad_out .contiguous (), idx , weight , m
143- )
160+ if grad_out .is_cuda :
161+ grad_features = tpcuda .three_interpolate_grad (
162+ grad_out .contiguous (), idx , weight , m
163+ )
164+ else :
165+ raise NotImplementedError
144166
145167 return grad_features , None , None
146168
@@ -174,7 +196,10 @@ def forward(ctx, features, idx):
174196
175197 ctx .for_backwards = (idx , N )
176198
177- return tpcuda .group_points (features , idx )
199+ if features .is_cuda :
200+ return tpcuda .group_points (features , idx )
201+ else :
202+ return tpcpu .group_points (features , idx )
178203
179204 @staticmethod
180205 def backward (ctx , grad_out ):
@@ -194,7 +219,10 @@ def backward(ctx, grad_out):
194219 """
195220 idx , N = ctx .for_backwards
196221
197- grad_features = tpcuda .group_points_grad (grad_out .contiguous (), idx , N )
222+ if grad_out .is_cuda :
223+ grad_features = tpcuda .group_points_grad (grad_out .contiguous (), idx , N )
224+ else :
225+ raise NotImplementedError
198226
199227 return grad_features , None
200228
@@ -220,7 +248,10 @@ class BallQuery(Function):
220248 @staticmethod
221249 def forward (ctx , radius , nsample , xyz , new_xyz ):
222250 # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
223- return tpcuda .ball_query (new_xyz , xyz , radius , nsample )
251+ if new_xyz .is_cuda :
252+ return tpcuda .ball_query (new_xyz , xyz , radius , nsample )
253+ else :
254+ raise NotImplementedError
224255
225256 @staticmethod
226257 def backward (ctx , a = None ):
0 commit comments