@@ -229,7 +229,9 @@ def forward(ctx, xyz1, xyz2):
229229 @staticmethod
230230 def backward (ctx , grad_dist1 , grad_dist2 ):
231231 xyz1 , xyz2 , idx1 , idx2 = ctx .saved_tensors
232- grad_xyz1 , grad_xyz2 = tpcuda .chamfer_dist_grad (xyz1 , xyz2 , idx1 , idx2 , grad_dist1 , grad_dist2 )
232+ grad_xyz1 , grad_xyz2 = tpcuda .chamfer_dist_grad (
233+ xyz1 , xyz2 , idx1 , idx2 , grad_dist1 , grad_dist2
234+ )
233235 return grad_xyz1 , grad_xyz2
234236
235237
@@ -260,3 +262,56 @@ def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
260262
261263 dist1 , dist2 = ChamferFunction .apply (xyz1 , xyz2 )
262264 return torch .mean (dist1 ) + torch .mean (dist2 )
265+
266+
267+ class CubicFeatureSamplingFunction (torch .autograd .Function ):
268+ @staticmethod
269+ def forward (ctx , ptcloud , cubic_features , neighborhood_size = 1 ):
270+ scale = cubic_features .size (2 )
271+ point_features , grid_pt_indexes = tpcuda .cubic_feature_sampling (
272+ scale , neighborhood_size , ptcloud , cubic_features
273+ )
274+ ctx .save_for_backward (
275+ torch .Tensor ([scale ]), torch .Tensor ([neighborhood_size ]), grid_pt_indexes
276+ )
277+ return point_features
278+
279+ @staticmethod
280+ def backward (ctx , grad_point_features ):
281+ scale , neighborhood_size , grid_pt_indexes = ctx .saved_tensors
282+ scale = int (scale .item ())
283+ neighborhood_size = int (neighborhood_size .item ())
284+ grad_point_features = grad_point_features .contiguous ()
285+ grad_ptcloud , grad_cubic_features = tpcuda .cubic_feature_sampling_grad (
286+ scale , neighborhood_size , grad_point_features , grid_pt_indexes
287+ )
288+ return grad_ptcloud , grad_cubic_features , None
289+
290+
291+ def cubic_feature_sampling (ptcloud , cubic_features , neighborhood_size = 1 ):
292+ r"""
293+ Sample the features of points from 3D feature maps that the point lies in.
294+ Please refer to https://arxiv.org/pdf/2006.03761 for more information
295+
296+ Parameters
297+ ----------
298+ ptcloud : torch.Tensor (dtype=torch.float32)
299+ (B, n_pts, 3) point clouds containing n_pts points
300+ cubic_features : torch.Tensor (dtype=torch.float32)
301+ (B, c, m, m, m) 3D feature maps of sizes m x m x m and c channels
302+ neighborhood_size : int
303+ The neighborhood cubes to sample.
304+ neighborhood_size = 1 means to sample the cube that point lies in.
305+ neighborhood_size = 2 means to sample surrouding cubes (step = 1) of
306+ the cube that point lies in.
307+
308+ Returns
309+ -------
310+ dist: torch.Tensor
311+ (B, n_pts, n_vertices, c), where n_vertices = (neighborhood_size * 2)^3
312+ """
313+ h_scale = cubic_features .size (2 ) / 2
314+ ptcloud = ptcloud * h_scale + h_scale
315+ return CubicFeatureSamplingFunction .apply (
316+ ptcloud , cubic_features , neighborhood_size
317+ )
0 commit comments