Skip to content

Commit 48606cf

Browse files
Address review comments
1 parent 3721d15 commit 48606cf

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

cpu/src/interpolate.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ at::Tensor knn_interpolate(at::Tensor features, at::Tensor idx, at::Tensor weigh
3030
{
3131
output_a[b][c][p] = 0;
3232
for (int i = 0; i < idx.size(2); i++)
33-
output_a[b][c][p] += features_a[b][c][idx_a[b][p][i]] * weight_a[b][p][i];
33+
{
34+
auto new_idx = idx_a[b][p][i];
35+
output_a[b][c][p] += features_a[b][c][new_idx] * weight_a[b][p][i];
36+
}
3437
}
3538
}
3639
}

cpu/src/neighbors.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,6 @@ void nanoflann_knn_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppor
321321
neighbors_indices[i + current_pos] = ret_index[i];
322322
dists[i + current_pos] = out_dist_sqr[i];
323323
}
324-
current_pos += nMatches;
324+
current_pos += k;
325325
}
326326
}

torch_points/knn.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch_points.points_cpu as tpcpu
22

3+
34
def knn(pos_support, pos, k):
45
""" Dense knn serach
56
Arguments:
@@ -12,5 +13,6 @@ def knn(pos_support, pos, k):
1213
dist2 - [B,M,k] squared distances
1314
"""
1415
assert pos_support.dim() == 3 and pos.dim() == 3
15-
16-
return tpcpu.dense_knn(pos_support, pos, k)
16+
if pos_support.is_cuda:
17+
raise ValueError("CUDA version not implemented, use pytorch geometric")
18+
return tpcpu.dense_knn(pos_support, pos, k)

0 commit comments

Comments
 (0)