@@ -30,7 +30,10 @@ def furthest_point_sample(xyz, npoint):
3030 (B, npoint) tensor containing the set
3131 """
3232 if npoint > xyz .shape [1 ]:
33- raise ValueError ("caanot sample %i points from an input set of %i points" % (npoint , xyz .shape [1 ]))
33+ raise ValueError (
34+ "caanot sample %i points from an input set of %i points"
35+ % (npoint , xyz .shape [1 ])
36+ )
3437 if xyz .is_cuda :
3538 return tpcuda .furthest_point_sampling (xyz , npoint )
3639 else :
@@ -99,9 +102,13 @@ def backward(ctx, grad_out):
99102 idx , weight , m = ctx .three_interpolate_for_backward
100103
101104 if grad_out .is_cuda :
102- grad_features = tpcuda .three_interpolate_grad (grad_out .contiguous (), idx , weight , m )
105+ grad_features = tpcuda .three_interpolate_grad (
106+ grad_out .contiguous (), idx , weight , m
107+ )
103108 else :
104- grad_features = tpcpu .knn_interpolate_grad (grad_out .contiguous (), idx , weight , m )
109+ grad_features = tpcpu .knn_interpolate_grad (
110+ grad_out .contiguous (), idx , weight , m
111+ )
105112
106113 return grad_features , None , None
107114
@@ -143,17 +150,23 @@ def grouping_operation(features, idx):
143150 all_idx = idx .reshape (idx .shape [0 ], - 1 )
144151 all_idx = all_idx .unsqueeze (1 ).repeat (1 , features .shape [1 ], 1 )
145152 grouped_features = features .gather (2 , all_idx )
146- return grouped_features .reshape (idx .shape [0 ], features .shape [1 ], idx .shape [1 ], idx .shape [2 ])
153+ return grouped_features .reshape (
154+ idx .shape [0 ], features .shape [1 ], idx .shape [1 ], idx .shape [2 ]
155+ )
147156
148157
149- def ball_query_dense (radius , nsample , xyz , new_xyz , batch_xyz = None , batch_new_xyz = None , sort = False ):
158+ def ball_query_dense (
159+ radius , nsample , xyz , new_xyz , batch_xyz = None , batch_new_xyz = None , sort = False
160+ ):
150161 # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
151162 if new_xyz .is_cuda :
152163 if sort :
153164 raise NotImplementedError ("CUDA version does not sort the neighbors" )
154165 ind , dist = tpcuda .ball_query_dense (new_xyz , xyz , radius , nsample )
155166 else :
156- ind , dist = tpcpu .dense_ball_query (new_xyz , xyz , radius , nsample , mode = 0 , sorted = sort )
167+ ind , dist = tpcpu .dense_ball_query (
168+ new_xyz , xyz , radius , nsample , mode = 0 , sorted = sort
169+ )
157170 return ind , dist
158171
159172
@@ -162,9 +175,13 @@ def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=False
162175 if x .is_cuda :
163176 if sort :
164177 raise NotImplementedError ("CUDA version does not sort the neighbors" )
165- ind , dist = tpcuda .ball_query_partial_dense (x , y , batch_x , batch_y , radius , nsample )
178+ ind , dist = tpcuda .ball_query_partial_dense (
179+ x , y , batch_x , batch_y , radius , nsample
180+ )
166181 else :
167- ind , dist = tpcpu .batch_ball_query (x , y , batch_x , batch_y , radius , nsample , mode = 0 , sorted = sort )
182+ ind , dist = tpcpu .batch_ball_query (
183+ x , y , batch_x , batch_y , radius , nsample , mode = 0 , sorted = sort
184+ )
168185 return ind , dist
169186
170187
@@ -207,7 +224,9 @@ def ball_query(
207224 assert x .size (0 ) == batch_x .size (0 )
208225 assert y .size (0 ) == batch_y .size (0 )
209226 assert x .dim () == 2
210- return ball_query_partial_dense (radius , nsample , x , y , batch_x , batch_y , sort = sort )
227+ return ball_query_partial_dense (
228+ radius , nsample , x , y , batch_x , batch_y , sort = sort
229+ )
211230
212231 elif mode .lower () == "dense" :
213232 if (batch_x is not None ) or (batch_y is not None ):
@@ -262,56 +281,3 @@ def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
262281
263282 dist1 , dist2 = ChamferFunction .apply (xyz1 , xyz2 )
264283 return torch .mean (dist1 ) + torch .mean (dist2 )
265-
266-
267- class CubicFeatureSamplingFunction (torch .autograd .Function ):
268- @staticmethod
269- def forward (ctx , ptcloud , cubic_features , neighborhood_size = 1 ):
270- scale = cubic_features .size (2 )
271- point_features , grid_pt_indexes = tpcuda .cubic_feature_sampling (
272- scale , neighborhood_size , ptcloud , cubic_features
273- )
274- ctx .save_for_backward (
275- torch .Tensor ([scale ]), torch .Tensor ([neighborhood_size ]), grid_pt_indexes
276- )
277- return point_features
278-
279- @staticmethod
280- def backward (ctx , grad_point_features ):
281- scale , neighborhood_size , grid_pt_indexes = ctx .saved_tensors
282- scale = int (scale .item ())
283- neighborhood_size = int (neighborhood_size .item ())
284- grad_point_features = grad_point_features .contiguous ()
285- grad_ptcloud , grad_cubic_features = tpcuda .cubic_feature_sampling_grad (
286- scale , neighborhood_size , grad_point_features , grid_pt_indexes
287- )
288- return grad_ptcloud , grad_cubic_features , None
289-
290-
291- def cubic_feature_sampling (ptcloud , cubic_features , neighborhood_size = 1 ):
292- r"""
293- Sample the features of points from 3D feature maps that the point lies in.
294- Please refer to https://arxiv.org/pdf/2006.03761 for more information
295-
296- Parameters
297- ----------
298- ptcloud : torch.Tensor (dtype=torch.float32)
299- (B, n_pts, 3) point clouds containing n_pts points
300- cubic_features : torch.Tensor (dtype=torch.float32)
301- (B, c, m, m, m) 3D feature maps of sizes m x m x m and c channels
302- neighborhood_size : int
303- The neighborhood cubes to sample.
304- neighborhood_size = 1 means to sample the cube that point lies in.
305- neighborhood_size = 2 means to sample surrouding cubes (step = 1) of
306- the cube that point lies in.
307-
308- Returns
309- -------
310- dist: torch.Tensor
311- (B, n_pts, n_vertices, c), where n_vertices = (neighborhood_size * 2)^3
312- """
313- h_scale = cubic_features .size (2 ) / 2
314- ptcloud = ptcloud * h_scale + h_scale
315- return CubicFeatureSamplingFunction .apply (
316- ptcloud , cubic_features , neighborhood_size
317- )
0 commit comments