Skip to content

Commit 31c06d8

Browse files
Fix tests
1 parent 340677a commit 31c06d8

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

cpu/include/neighbors.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppo
2020

2121
template <typename scalar_t>
2222
void nanoflann_knn_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
23-
vector<long>& neighbors_indices, vector<float>& dists, int k);
23+
vector<long>& neighbors_indices, vector<float>& dists, int k);

cpu/src/knn.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ std::pair<at::Tensor, at::Tensor> _single_batch_knn(at::Tensor support, at::Tens
1616
std::vector<float> neighbors_dists(query.size(0) * k, -1);
1717

1818
auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
19-
auto options_dist = torch::TensorOptions().dtype(query.scalar_type()).device(torch::kCPU);
19+
auto options_dist = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
2020
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "knn", [&] {
2121
auto data_q = query.DATA_PTR<scalar_t>();
2222
auto data_s = support.DATA_PTR<scalar_t>();

test/test_interpolate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_grad(self):
3636
dist_recip = 1.0 / (dist + 1e-8)
3737
norm = torch.sum(dist_recip, dim=2, keepdim=True)
3838
weight = dist_recip / norm
39-
input = (x, idx, weight)
39+
input = (x, idx, weight.double())
4040
test = gradcheck(three_interpolate, input, eps=1e-6, atol=1e-4)
4141

4242

test/test_knn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313
class TestKnn(unittest.TestCase):
1414
def test_cpu(self):
15-
support = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]])
16-
query = torch.tensor([[[0, 0, 0]]])
15+
support = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).float()
16+
query = torch.tensor([[[0, 0, 0]]]).float()
1717

1818
idx, dist = knn(support, query, 3)
1919
torch.testing.assert_allclose(idx, torch.tensor([[[0, 1, 2]]]))

0 commit comments

Comments
 (0)