Skip to content

Commit 488825f

Browse files
committed
Create more test cases for the forward function.
1 parent e9ee3bd commit 488825f

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

test/test_chamfer_dist.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,51 @@
1+
import numpy as np
12
import os
23
import sys
34
import torch
45
import unittest
56

67
from torch.autograd import gradcheck
78

9+
from . import run_if_cuda
10+
11+
812
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
913
sys.path.insert(0, ROOT)
1014

11-
from torch_points_kernels import ChamferFunction
15+
from torch_points_kernels import ChamferFunction, chamfer_dist
1216

1317

1418
class TestChamferDistance(unittest.TestCase):
15-
def test_chamfer_dist(self):
19+
@run_if_cuda
20+
def test_chamfer_dist_grad(self):
1621
x = torch.rand(4, 64, 3).double()
1722
y = torch.rand(4, 128, 3).double()
1823
x.requires_grad = True
1924
y.requires_grad = True
2025
test = gradcheck(ChamferFunction.apply, [x.cuda(), y.cuda()])
2126

27+
@run_if_cuda
28+
def test_chamfer_dist(self):
29+
xyz1 = torch.from_numpy(np.array([[
30+
[0, 0, 0],
31+
[1, 1, 1],
32+
[2, 0, 1]
33+
]])).float()
34+
xyz2 = torch.from_numpy(np.array([[[1, 0, 0], [1, 2, 1]]])).float()
35+
dist = chamfer_dist(xyz1.cuda(), xyz2.cuda())
36+
self.assertAlmostEqual(dist.item(), 2.333333, places=5)
37+
38+
@run_if_cuda
39+
def test_chamfer_dist_ignore_zeros(self):
40+
xyz1 = torch.from_numpy(np.array([[
41+
[0, 0, 0],
42+
[1, 1, 1],
43+
[2, 0, 1]
44+
]])).float()
45+
xyz2 = torch.from_numpy(np.array([[[1, 0, 0], [1, 2, 1]]])).float()
46+
dist = chamfer_dist(xyz1.cuda(), xyz2.cuda(), True)
47+
self.assertAlmostEqual(dist.item(), 3.0, places=5)
48+
2249

23-
if __name__ == '__main__':
50+
if __name__ == "__main__":
2451
unittest.main()

torch_points_kernels/torchpoints.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,24 @@ def backward(ctx, grad_dist1, grad_dist2):
252252
return grad_xyz1, grad_xyz2
253253

254254

255-
def chamfer_dist(self, xyz1, xyz2, ignore_zeros=False):
255+
def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
256+
r"""
257+
Calcuates the distance between B pairs of point clouds
258+
259+
Parameters
260+
----------
261+
xyz1 : torch.Tensor (dtype=torch.float32)
262+
(B, n1, 3) B point clouds containing n1 points
263+
xyz2 : torch.Tensor (dtype=torch.float32)
264+
(B, n2, 3) B point clouds containing n2 points
265+
ignore_zeros : bool
266+
ignore the point whose coordinate is (0, 0, 0) or not
267+
268+
Returns
269+
-------
270+
dist: torch.Tensor
271+
(B, ): the distances between B pairs of point clouds
272+
"""
256273
batch_size = xyz1.size(0)
257274
if batch_size == 1 and ignore_zeros:
258275
non_zeros1 = torch.sum(xyz1, dim=2).ne(0)

0 commit comments

Comments
 (0)