Skip to content

Commit edb7176

Browse files
author
humanpose1
committed
problem of seg fault for small batches :'(
1 parent c0daf2c commit edb7176

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

test/test_ballquerry_partial.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def test_simple_gpu(self):
1111
y = torch.tensor([[0, 0, 0]]).to(torch.float).cuda()
1212
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
1313
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()
14-
14+
1515
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
1616
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()
1717

@@ -26,5 +26,25 @@ def test_simple_gpu(self):
2626
npt.assert_array_almost_equal(idx, idx_answer)
2727
npt.assert_array_almost_equal(dist2, dist2_answer)
2828

29+
def test_simple_cpu(self):
30+
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]]).to(torch.float)
31+
y = torch.tensor([[0, 0, 0]]).to(torch.float)
32+
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long()
33+
batch_y = torch.from_numpy(np.asarray([0])).long()
34+
35+
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long()
36+
batch_y = torch.from_numpy(np.asarray([0])).long()
37+
38+
idx, dist2 = ball_query(1., 2, x, y, batch_x, batch_y, mode="PARTIAL_DENSE")
39+
40+
idx = idx.detach().cpu().numpy()
41+
dist2 = dist2.detach().cpu().numpy()
42+
43+
idx_answer = np.asarray([[1, 1], [0, 1], [1, 1], [1, 1]])
44+
dist2_answer = np.asarray([[-1, -1], [0.01, -1], [-1, -1], [-1, -1]]).astype(np.float32)
45+
46+
npt.assert_array_almost_equal(idx, idx_answer)
47+
npt.assert_array_almost_equal(dist2, dist2_answer)
48+
2949
if __name__ == "__main__":
3050
unittest.main()

torch_points/torchpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def forward(ctx, radius, nsample, x, y, batch_x, batch_y):
292292
batch_y,
293293
radius, nsample)
294294
else:
295-
ind, dist = tpcpu.dense_ball_query(x, y,
295+
ind, dist = tpcpu.batch_ball_query(x, y,
296296
batch_x,
297297
batch_y,
298298
radius, nsample, mode=0)

0 commit comments

Comments
 (0)