Skip to content

Commit f1cb6c0

Browse files
author
humanpose1
committed
add unittest
1 parent 9ebec86 commit f1cb6c0

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

test/test_ballquerry.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,24 @@
66

77

88
class TestBall(unittest.TestCase):
9-
def test_simple(self):
9+
def test_simple_gpu(self):
1010
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float).cuda()
1111
b = torch.tensor([[[0, 0, 0]]]).to(torch.float).cuda()
1212

1313
npt.assert_array_equal(ball_query(1, 2, a, b).detach().cpu().numpy(), np.array([[[0, 0]]]))
1414

15+
def test_simple_cpu(self):
16+
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float)
17+
b = torch.tensor([[[0, 0, 0]]]).to(torch.float)
18+
npt.assert_array_equal(ball_query(1, 2, a, b).detach().numpy(), np.array([[[0, 0]]]))
19+
20+
def test_cpu_gpu_equality(self):
21+
a = torch.randn(5, 1000, 3)
22+
npt.assert_array_equal(ball_query(0.1, 17, a, a).detach().numpy(),
23+
ball_query(0.1, 17, a.cuda(), a.cuda()).detach().numpy())
24+
25+
26+
1527

1628
if __name__ == "__main__":
1729
unittest.main()

0 commit comments

Comments
 (0)