Skip to content

Commit ea6277f

Browse files
Merge pull request #55 from hzxie/chamfer-dist
Move the Chamfer Distance to a new file
2 parents 4e384b8 + 87f5aa5 commit ea6277f

File tree

3 files changed

+60
-44
lines changed

3 files changed

+60
-44
lines changed

test/test_chamfer_dist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
1313
sys.path.insert(0, ROOT)
1414

15-
from torch_points_kernels import ChamferFunction, chamfer_dist
15+
from torch_points_kernels.chamfer_dist import ChamferFunction, chamfer_dist
1616

1717

1818
class TestChamferDistance(unittest.TestCase):
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import torch
2+
3+
if torch.cuda.is_available():
4+
import torch_points_kernels.points_cuda as tpcuda
5+
6+
7+
class ChamferFunction(torch.autograd.Function):
8+
@staticmethod
9+
def forward(ctx, xyz1, xyz2):
10+
if not torch.cuda.is_available():
11+
raise NotImplementedError(
12+
"CPU version is not available for Chamfer Distance"
13+
)
14+
15+
dist1, dist2, idx1, idx2 = tpcuda.chamfer_dist(xyz1, xyz2)
16+
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
17+
18+
return dist1, dist2
19+
20+
@staticmethod
21+
def backward(ctx, grad_dist1, grad_dist2):
22+
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
23+
grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(
24+
xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2
25+
)
26+
return grad_xyz1, grad_xyz2
27+
28+
29+
def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
30+
r"""
31+
Calcuates the distance between B pairs of point clouds
32+
33+
Parameters
34+
----------
35+
xyz1 : torch.Tensor (dtype=torch.float32)
36+
(B, n1, 3) B point clouds containing n1 points
37+
xyz2 : torch.Tensor (dtype=torch.float32)
38+
(B, n2, 3) B point clouds containing n2 points
39+
ignore_zeros : bool
40+
ignore the point whose coordinate is (0, 0, 0) or not
41+
42+
Returns
43+
-------
44+
dist: torch.Tensor
45+
(B, ): the distances between B pairs of point clouds
46+
"""
47+
if len(xyz1.shape) != 3 or xyz1.size(2) != 3 or len(xyz2.shape) != 3 or xyz2.size(2) != 3:
48+
raise ValueError('The input point cloud should be of size (B, n_pts, 3)')
49+
50+
batch_size = xyz1.size(0)
51+
if batch_size == 1 and ignore_zeros:
52+
non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
53+
non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
54+
xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
55+
xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
56+
57+
dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
58+
return torch.mean(dist1) + torch.mean(dist2)
59+

torch_points_kernels/torchpoints.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -217,46 +217,3 @@ def ball_query(
217217
else:
218218
raise Exception("unrecognized mode {}".format(mode))
219219

220-
221-
class ChamferFunction(Function):
222-
@staticmethod
223-
def forward(ctx, xyz1, xyz2):
224-
dist1, dist2, idx1, idx2 = tpcuda.chamfer_dist(xyz1, xyz2)
225-
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
226-
227-
return dist1, dist2
228-
229-
@staticmethod
230-
def backward(ctx, grad_dist1, grad_dist2):
231-
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
232-
grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2)
233-
return grad_xyz1, grad_xyz2
234-
235-
236-
def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
237-
r"""
238-
Calcuates the distance between B pairs of point clouds
239-
240-
Parameters
241-
----------
242-
xyz1 : torch.Tensor (dtype=torch.float32)
243-
(B, n1, 3) B point clouds containing n1 points
244-
xyz2 : torch.Tensor (dtype=torch.float32)
245-
(B, n2, 3) B point clouds containing n2 points
246-
ignore_zeros : bool
247-
ignore the point whose coordinate is (0, 0, 0) or not
248-
249-
Returns
250-
-------
251-
dist: torch.Tensor
252-
(B, ): the distances between B pairs of point clouds
253-
"""
254-
batch_size = xyz1.size(0)
255-
if batch_size == 1 and ignore_zeros:
256-
non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
257-
non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
258-
xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
259-
xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
260-
261-
dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
262-
return torch.mean(dist1) + torch.mean(dist2)

0 commit comments

Comments
 (0)