Skip to content

Commit c0daf2c

Browse files
author
humanpose1
committed
fix problem of cpu
1 parent 01407c7 commit c0daf2c

File tree

6 files changed

+65
-25
lines changed

6 files changed

+65
-25
lines changed

cpu/include/ball_query.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,7 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query,
99
at::Tensor query_batch,
1010
at::Tensor support_batch,
1111
float radius, int max_num, int mode);
12+
13+
/// Radius search over a dense batch of point clouds.
///
/// @param query    query points; the Python binding documents this as a
///                 tensor of size B x N1 x 3.
/// @param support  support points; the Python binding documents this as a
///                 tensor of size B x N2 x 3.
/// @param radius   size of the ball used for the radius search.
/// @param max_num  maximum number of neighbors allowed per query point
///                 (the Python binding defaults it to -1).
/// @param mode     selects the output format (the Python binding defaults
///                 it to 0).
/// @return pair of tensors: neighbor indices first, neighbor distances second.
std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor query,
14+
at::Tensor support,
15+
float radius, int max_num, int mode);

cpu/src/bindings.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
3535
"mode=1 means a matrix of edges of size Num_edge x 2"
3636
"return a tensor of size N1 x M where M is either max_num or the maximum number of neighbors found if mode = 0, if mode=1 return a tensor of size Num_edge x 2 and return a tensor containing the squared distance of the neighbors",
3737
"query"_a, "support"_a, "query_batch"_a, "support_batch"_a, "radius"_a, "max_num"_a=-1, "mode"_a=0);
38+
// Expose the dense (batched) radius search to Python.
// Adjacent C++ string literals are concatenated verbatim, so every docstring
// fragment carries its own trailing '\n'; without it the help text shown by
// Python's help() runs all the bullet points together on a single line.
m.def("dense_ball_query", &dense_ball_query,
      "compute the radius search of a batch of point clouds using nanoflann\n"
      "- query : a pytorch tensor of size B x N1 x 3, used to query the nearest neighbors\n"
      "- support : a pytorch tensor of size B x N2 x 3, used to build the tree\n"
      "- radius : float number, size of the ball for the radius search.\n"
      "- max_num : int number, indicates the maximum number of neighbors allowed "
      "(if -1 then all the possible neighbors will be computed).\n"
      "- mode : int number that indicates which format is used for the neighborhood\n"
      "    mode=0 means a matrix of neighbors (-1 for shadow neighbors)\n"
      "    mode=1 means a matrix of edges of size Num_edge x 2\n"
      "return a tensor of size B x N1 x M where M is either max_num or the maximum "
      "number of neighbors found if mode = 0, if mode=1 return a tensor of size "
      "B x Num_edge x 2, and return a tensor containing the squared distance of the neighbors\n",
      "query"_a, "support"_a, "radius"_a, "max_num"_a=-1, "mode"_a=0);
3849
}

cpu/src/neighbors.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,20 @@ int nanoflann_neighbors(vector<scalar_t>& queries,
8282

8383
i0 = 0;
8484

85+
int token = 0;
8586
for (auto& inds : list_matches){
87+
token = inds[0].first;
8688
for (int j = 0; j < max_count; j++){
8789
if (j < inds.size()){
8890
neighbors_indices[i0 * max_count + j] = inds[j].first;
8991
dists[i0 * max_count + j] = (float) inds[j].second;
92+
93+
9094
}
9195

9296
else {
93-
neighbors_indices[i0 * max_count + j] = -1;
94-
dists[i0 * max_count + j] = radius * radius;
97+
neighbors_indices[i0 * max_count + j] = token;
98+
dists[i0 * max_count + j] = -1;
9599
}
96100
}
97101
i0++;
@@ -239,8 +243,8 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
239243
dists[i0 * max_count + j] = (float) inds_dists[j].second;
240244
}
241245
else {
242-
neighbors_indices[i0 * max_count + j] = supports.size();
243-
dists[i0 * max_count + j] = radius * radius;
246+
neighbors_indices[i0 * max_count + j] = supports.size()/3;
247+
dists[i0 * max_count + j] = -1;
244248
}
245249

246250
}

cpu/src/torch_nearest_neighbors.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,23 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query,
135135
}
136136
return std::make_pair(out.clone(), out_dists.clone());
137137
}
138+
139+
140+
/// Radius search over a dense batch of point clouds on the CPU.
///
/// Runs `ball_query` independently on every batch element (`query[i]`
/// against `support[i]`) and stacks the per-element results along a new
/// leading batch dimension.
///
/// @param query    batched query points; the first dimension is the batch.
/// @param support  batched support points, indexed with the same batch index
///                 as `query`.
/// @param radius   forwarded unchanged to `ball_query`.
/// @param max_num  forwarded unchanged to `ball_query`.
/// @param mode     forwarded unchanged to `ball_query`.
/// @return pair (indices, distances), each produced by `torch::stack` over
///         the per-batch outputs of `ball_query`.
/// @note `torch::stack` needs a non-empty list of equally-shaped tensors, so
///       every batch element must yield results of the same shape.
std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor query,
                                                   at::Tensor support,
                                                   float radius, int max_num, int mode){

    // Tensor::size() returns int64_t; keep that width instead of narrowing to int.
    const int64_t batch_size = query.size(0);

    std::vector<at::Tensor> batch_idx;
    std::vector<at::Tensor> batch_dist;
    // The number of results is known up front: allocate once.
    batch_idx.reserve(static_cast<size_t>(batch_size));
    batch_dist.reserve(static_cast<size_t>(batch_size));

    for (int64_t i = 0; i < batch_size; ++i){
        auto out_pair = ball_query(query[i], support[i], radius, max_num, mode);
        // Move the tensor handles into the lists rather than copying them.
        batch_idx.push_back(std::move(out_pair.first));
        batch_dist.push_back(std::move(out_pair.second));
    }

    return std::make_pair(torch::stack(batch_idx), torch::stack(batch_dist));
}

test/test_ballquerry.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@ def test_simple_cpu(self):
1919

2020
def test_cpu_gpu_equality(self):
    """Check that the CPU and CUDA dense ball queries find the same neighbors."""
    # Without a CUDA device the `.cuda()` calls below raise, so skip instead
    # of reporting an error.
    # NOTE(review): skipTest assumes the enclosing class is a
    # unittest.TestCase - confirm.
    if not torch.cuda.is_available():
        self.skipTest("CUDA is not available")

    a = torch.randn(5, 1000, 3)
    res_cpu = ball_query_dense(0.1, 17, a, a).detach().numpy()
    res_cuda = ball_query_dense(0.1, 17, a.cuda(), a.cuda()).cpu().detach().numpy()
    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            # The two backends do not necessarily return the neighbors in
            # the same order, so compare them as sets.
            assert set(res_cpu[i][j]) == set(res_cuda[i][j])
2428

2529

2630
if __name__ == "__main__":

torch_points/torchpoints.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -252,17 +252,10 @@ def forward(ctx, radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=No
252252
if new_xyz.is_cuda:
253253
return tpcuda.ball_query_dense(new_xyz, xyz, radius, nsample)
254254
else:
255-
b = xyz.size(0)
256-
npoints = new_xyz.size(1)
257-
n = xyz.size(1)
258-
batch_new_xyz = torch.arange(0, b, dtype=torch.long).repeat(npoints, 1).T.reshape(-1)
259-
batch_xyz = torch.arange(0, b, dtype=torch.long).repeat(n, 1).T.reshape(-1)
260-
ind, dist = tpcpu.batch_ball_query(new_xyz.view(-1, 3),
261-
xyz.view(-1, 3),
262-
batch_new_xyz,
263-
batch_xyz,
264-
radius, nsample)
265-
return ind.view(b, npoints, nsample)
255+
ind, dist = tpcpu.dense_ball_query(new_xyz,
256+
xyz,
257+
radius, nsample, mode=0)
258+
return ind
266259

267260
@staticmethod
268261
def backward(ctx, a=None):
@@ -294,12 +287,16 @@ class BallQueryPartialDense(Function):
294287
def forward(ctx, radius, nsample, x, y, batch_x, batch_y):
    # type: (Any, float, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> Any
    """Ball query on partially dense (batch-indexed) point clouds.

    Dispatches to the CUDA kernel when ``x`` lives on a CUDA device and to
    the CPU extension otherwise. ``batch_x`` and ``batch_y`` carry the batch
    index of every point of ``x`` and ``y`` respectively.

    The CPU branch returns the pair ``(ind, dist)`` produced by
    ``tpcpu.batch_ball_query``; the CUDA branch returns whatever
    ``tpcuda.ball_query_partial_dense`` returns.
    """
    if x.is_cuda:
        return tpcuda.ball_query_partial_dense(x, y,
                                               batch_x,
                                               batch_y,
                                               radius, nsample)

    # The batch vectors must go to `batch_ball_query`, which is bound as
    # (query, support, query_batch, support_batch, radius, max_num, mode).
    # `dense_ball_query` only accepts (query, support, radius, max_num, mode)
    # and rejects this argument list.
    ind, dist = tpcpu.batch_ball_query(x, y,
                                       batch_x,
                                       batch_y,
                                       radius, nsample, mode=0)
    return ind, dist
303300

304301
@staticmethod
305302
def backward(ctx, a=None):
@@ -319,9 +316,9 @@ def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y):
319316
y : torch.Tensor
320317
(N, npoint, 3) centers of the ball query
321318
batch_x : torch.Tensor
322-
(M, ) Contains indexes to indicate within batch it belongs to.
319+
(M, ) Contains indexes indicating which batch each point belongs to.
323320
batch_y : torch.Tensor
324-
(N, ) Contains indexes to indicate within batch it belongs to
321+
(N, ) Contains indexes indicating which batch each point belongs to.
325322
326323
Returns
327324
-------
@@ -347,4 +344,4 @@ def ball_query(radius: float, nsample: int, x, y, batch_x=None, batch_y=None, mo
347344
raise Exception('batch_x and batch_y should not be provided')
348345
return ball_query_dense(radius, nsample, x, y)
349346
else:
350-
raise Exception('unrecognized mode {}'.format(mode))
347+
raise Exception('unrecognized mode {}'.format(mode))

0 commit comments

Comments
 (0)