|
| 1 | +import torch |
| 2 | + |
| 3 | +if torch.cuda.is_available(): |
| 4 | + import torch_points_kernels.points_cuda as tpcuda |
| 5 | + |
| 6 | + |
| 7 | +class ChamferFunction(torch.autograd.Function): |
| 8 | + @staticmethod |
| 9 | + def forward(ctx, xyz1, xyz2): |
| 10 | + if not torch.cuda.is_available(): |
| 11 | + raise NotImplementedError( |
| 12 | + "CPU version is not available for Chamfer Distance" |
| 13 | + ) |
| 14 | + |
| 15 | + dist1, dist2, idx1, idx2 = tpcuda.chamfer_dist(xyz1, xyz2) |
| 16 | + ctx.save_for_backward(xyz1, xyz2, idx1, idx2) |
| 17 | + |
| 18 | + return dist1, dist2 |
| 19 | + |
| 20 | + @staticmethod |
| 21 | + def backward(ctx, grad_dist1, grad_dist2): |
| 22 | + xyz1, xyz2, idx1, idx2 = ctx.saved_tensors |
| 23 | + grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad( |
| 24 | + xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2 |
| 25 | + ) |
| 26 | + return grad_xyz1, grad_xyz2 |
| 27 | + |
| 28 | + |
| 29 | +def chamfer_dist(xyz1, xyz2, ignore_zeros=False): |
| 30 | + r""" |
| 31 | + Calcuates the distance between B pairs of point clouds |
| 32 | +
|
| 33 | + Parameters |
| 34 | + ---------- |
| 35 | + xyz1 : torch.Tensor (dtype=torch.float32) |
| 36 | + (B, n1, 3) B point clouds containing n1 points |
| 37 | + xyz2 : torch.Tensor (dtype=torch.float32) |
| 38 | + (B, n2, 3) B point clouds containing n2 points |
| 39 | + ignore_zeros : bool |
| 40 | + ignore the point whose coordinate is (0, 0, 0) or not |
| 41 | +
|
| 42 | + Returns |
| 43 | + ------- |
| 44 | + dist: torch.Tensor |
| 45 | + (B, ): the distances between B pairs of point clouds |
| 46 | + """ |
| 47 | + if len(xyz1.shape) != 3 or xyz1.size(2) != 3 or len(xyz2.shape) != 3 or xyz2.size(2) != 3: |
| 48 | + raise ValueError('The input point cloud should be of size (B, n_pts, 3)') |
| 49 | + |
| 50 | + batch_size = xyz1.size(0) |
| 51 | + if batch_size == 1 and ignore_zeros: |
| 52 | + non_zeros1 = torch.sum(xyz1, dim=2).ne(0) |
| 53 | + non_zeros2 = torch.sum(xyz2, dim=2).ne(0) |
| 54 | + xyz1 = xyz1[non_zeros1].unsqueeze(dim=0) |
| 55 | + xyz2 = xyz2[non_zeros2].unsqueeze(dim=0) |
| 56 | + |
| 57 | + dist1, dist2 = ChamferFunction.apply(xyz1, xyz2) |
| 58 | + return torch.mean(dist1) + torch.mean(dist2) |
| 59 | + |
0 commit comments