Skip to content

Commit db6c376

Browse files
Merge branch 'master' into cubic-feature-sampling
2 parents be5d95e + ea6277f commit db6c376

File tree

3 files changed

+60
-48
lines changed

3 files changed

+60
-48
lines changed

test/test_chamfer_dist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
1313
sys.path.insert(0, ROOT)
1414

15-
from torch_points_kernels import ChamferFunction, chamfer_dist
15+
from torch_points_kernels.chamfer_dist import ChamferFunction, chamfer_dist
1616

1717

1818
class TestChamferDistance(unittest.TestCase):
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import torch
2+
3+
if torch.cuda.is_available():
4+
import torch_points_kernels.points_cuda as tpcuda
5+
6+
7+
class ChamferFunction(torch.autograd.Function):
8+
@staticmethod
9+
def forward(ctx, xyz1, xyz2):
10+
if not torch.cuda.is_available():
11+
raise NotImplementedError(
12+
"CPU version is not available for Chamfer Distance"
13+
)
14+
15+
dist1, dist2, idx1, idx2 = tpcuda.chamfer_dist(xyz1, xyz2)
16+
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
17+
18+
return dist1, dist2
19+
20+
@staticmethod
21+
def backward(ctx, grad_dist1, grad_dist2):
22+
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
23+
grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(
24+
xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2
25+
)
26+
return grad_xyz1, grad_xyz2
27+
28+
29+
def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
30+
r"""
31+
Calcuates the distance between B pairs of point clouds
32+
33+
Parameters
34+
----------
35+
xyz1 : torch.Tensor (dtype=torch.float32)
36+
(B, n1, 3) B point clouds containing n1 points
37+
xyz2 : torch.Tensor (dtype=torch.float32)
38+
(B, n2, 3) B point clouds containing n2 points
39+
ignore_zeros : bool
40+
ignore the point whose coordinate is (0, 0, 0) or not
41+
42+
Returns
43+
-------
44+
dist: torch.Tensor
45+
(B, ): the distances between B pairs of point clouds
46+
"""
47+
if len(xyz1.shape) != 3 or xyz1.size(2) != 3 or len(xyz2.shape) != 3 or xyz2.size(2) != 3:
48+
raise ValueError('The input point cloud should be of size (B, n_pts, 3)')
49+
50+
batch_size = xyz1.size(0)
51+
if batch_size == 1 and ignore_zeros:
52+
non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
53+
non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
54+
xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
55+
xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
56+
57+
dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
58+
return torch.mean(dist1) + torch.mean(dist2)
59+

torch_points_kernels/torchpoints.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -235,50 +235,3 @@ def ball_query(
235235
return ball_query_dense(radius, nsample, x, y, sort=sort)
236236
else:
237237
raise Exception("unrecognized mode {}".format(mode))
238-
239-
240-
class ChamferFunction(Function):
241-
@staticmethod
242-
def forward(ctx, xyz1, xyz2):
243-
dist1, dist2, idx1, idx2 = tpcuda.chamfer_dist(xyz1, xyz2)
244-
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
245-
246-
return dist1, dist2
247-
248-
@staticmethod
249-
def backward(ctx, grad_dist1, grad_dist2):
250-
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
251-
grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(
252-
xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2
253-
)
254-
return grad_xyz1, grad_xyz2
255-
256-
257-
def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
258-
r"""
259-
Calcuates the distance between B pairs of point clouds
260-
261-
Parameters
262-
----------
263-
xyz1 : torch.Tensor (dtype=torch.float32)
264-
(B, n1, 3) B point clouds containing n1 points
265-
xyz2 : torch.Tensor (dtype=torch.float32)
266-
(B, n2, 3) B point clouds containing n2 points
267-
ignore_zeros : bool
268-
ignore the point whose coordinate is (0, 0, 0) or not
269-
270-
Returns
271-
-------
272-
dist: torch.Tensor
273-
(B, ): the distances between B pairs of point clouds
274-
"""
275-
batch_size = xyz1.size(0)
276-
if batch_size == 1 and ignore_zeros:
277-
non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
278-
non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
279-
xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
280-
xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
281-
282-
dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
283-
return torch.mean(dist1) + torch.mean(dist2)
284-

0 commit comments

Comments
 (0)