|
6 | 6 |
|
7 | 7 |
|
8 | 8 | class TestBall(unittest.TestCase): |
9 | | - def test_simple(self): |
| 9 | + def test_simple_gpu(self): |
10 | 10 | a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float).cuda() |
11 | 11 | b = torch.tensor([[[0, 0, 0]]]).to(torch.float).cuda() |
12 | 12 |
|
13 | 13 | npt.assert_array_equal(ball_query(1, 2, a, b).detach().cpu().numpy(), np.array([[[0, 0]]])) |
14 | 14 |
|
| 15 | + def test_simple_cpu(self): |
| 16 | + a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float) |
| 17 | + b = torch.tensor([[[0, 0, 0]]]).to(torch.float) |
| 18 | + npt.assert_array_equal(ball_query(1, 2, a, b).detach().numpy(), np.array([[[0, 0]]])) |
| 19 | + |
| 20 | + def test_cpu_gpu_equality(self): |
| 21 | + a = torch.randn(5, 1000, 3) |
| 22 | + npt.assert_array_equal(ball_query(0.1, 17, a, a).detach().numpy(), |
| 23 | + ball_query(0.1, 17, a.cuda(), a.cuda()).detach().numpy()) |
| 24 | + |
| 25 | + |
| 26 | + |
15 | 27 |
|
16 | 28 | if __name__ == "__main__": |
17 | 29 | unittest.main() |
0 commit comments