Skip to content

Commit 71982c8

Browse files
Basic region growing clustering
1 parent 8cc7b07 commit 71982c8

File tree

4 files changed

+208
-9
lines changed

4 files changed

+208
-9
lines changed

benchmark/region_cluster.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import unittest
2+
import torch
3+
import os
4+
import sys
5+
import time
6+
import random
7+
8+
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
9+
sys.path.insert(0, ROOT)
10+
11+
from torch_points_kernels.cluster import grow_proximity
12+
13+
torch.manual_seed(0)
14+
15+
num_points = 100000
16+
pos1 = torch.rand((num_points, 3))
17+
pos2 = torch.rand((num_points, 3)) + 2
18+
pos3 = torch.rand((num_points, 3)) + 4
19+
labels1 = torch.ones(num_points).long()
20+
labels2 = torch.ones(num_points).long()
21+
labels3 = torch.ones(num_points).long()
22+
pos = torch.cat([pos1, pos2, pos3], 0)
23+
label = torch.cat([labels1, labels2, labels3], 0)
24+
batch = torch.ones((3 * num_points)).long()
25+
cl = grow_proximity(pos, batch, radius=0.5)
26+
27+
28+
import cProfile, pstats, io
29+
from pstats import SortKey
30+
31+
pr = cProfile.Profile()
32+
pr.enable()
33+
t_start = time.perf_counter()
34+
grow_proximity(pos, batch)
35+
print(time.perf_counter() - t_start)
36+
pr.disable()
37+
s = io.StringIO()
38+
sortby = SortKey.CUMULATIVE
39+
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
40+
ps.print_stats()
41+
print(s.getvalue())

test/test_cluster.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import unittest
2+
import torch
3+
import os
4+
import sys
5+
6+
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
7+
sys.path.insert(0, ROOT)
8+
9+
from torch_points_kernels.cluster import grow_proximity, region_grow
10+
11+
12+
class TestGrow(unittest.TestCase):
13+
def setUp(self):
14+
self.pos = torch.tensor(
15+
[
16+
[0, 0, 0],
17+
[1, 0, 0],
18+
[2, 0, 0],
19+
[10, 0, 0],
20+
[0, 0, 0],
21+
[1, 0, 0],
22+
[2, 0, 0],
23+
[10, 0, 0],
24+
]
25+
)
26+
self.batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
27+
self.labels = torch.tensor([0, 0, 1, 1, 0, 1, 1, 10])
28+
29+
def test_simple(self):
30+
clusters = grow_proximity(self.pos, self.batch, radius=2, min_cluster_size=1)
31+
self.assertEqual(clusters, [[0, 1, 2], [3], [4, 5, 6], [7]])
32+
33+
clusters = grow_proximity(self.pos, self.batch, radius=2, min_cluster_size=3)
34+
self.assertEqual(clusters, [[0, 1, 2], [4, 5, 6]])
35+
36+
def test_region_grow(self):
37+
clusters = region_grow(
38+
self.pos, self.labels, self.batch, radius=2, min_cluster_size=1
39+
)
40+
self.assertEqual(len(clusters[0]), 2)
41+
self.assertEqual(len(clusters[1]), 3)
42+
self.assertEqual(len(clusters[10]), 1)
43+
torch.testing.assert_allclose(clusters[0][0], torch.tensor([0, 1]))
44+
torch.testing.assert_allclose(clusters[0][1], torch.tensor([4]))
45+
torch.testing.assert_allclose(clusters[1][0], torch.tensor([2]))
46+
torch.testing.assert_allclose(clusters[1][1], torch.tensor([3]))
47+
torch.testing.assert_allclose(clusters[1][2], torch.tensor([5, 6]))
48+
torch.testing.assert_allclose(clusters[10][0], torch.tensor([7]))
49+
50+
51+
if __name__ == "__main__":
52+
unittest.main()

torch_points_kernels/cluster.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import torch
2+
from .torchpoints import ball_query_partial_dense
3+
import numpy as np
4+
import numba
5+
6+
7+
@numba.jit(nopython=True)
8+
def _grow_proximity_core(neighbours, min_cluster_size):
9+
num_points = int(neighbours.shape[0])
10+
visited = np.zeros((num_points,), dtype=numba.types.bool_)
11+
clusters = []
12+
for i in range(num_points):
13+
if visited[i]:
14+
continue
15+
16+
cluster = []
17+
queue = []
18+
visited[i] = True
19+
queue.append(i)
20+
cluster.append(i)
21+
while len(queue):
22+
k = queue.pop()
23+
k_neighbours = neighbours[k]
24+
for nei in k_neighbours:
25+
if nei.item() == -1:
26+
break
27+
28+
if not visited[nei]:
29+
visited[nei] = True
30+
queue.append(nei.item())
31+
cluster.append(nei.item())
32+
33+
if len(cluster) >= min_cluster_size:
34+
clusters.append(cluster)
35+
36+
return clusters
37+
38+
39+
def grow_proximity(pos, batch, nsample=16, radius=0.02, min_cluster_size=32):
40+
""" Grow based on proximity only"""
41+
assert pos.shape[0] == batch.shape[0]
42+
neighbours = (
43+
ball_query_partial_dense(radius, nsample, pos, pos, batch, batch)[0]
44+
.cpu()
45+
.numpy()
46+
)
47+
return _grow_proximity_core(neighbours, min_cluster_size)
48+
49+
50+
def region_grow(
51+
pos, labels, batch, ignore_labels=[], nsample=16, radius=0.02, min_cluster_size=32
52+
):
53+
""" Region growing clustering algorithm proposed in
54+
PointGroup: Dual-Set Point Grouping for 3D Instance Segmentation
55+
https://arxiv.org/pdf/2004.01658.pdf
56+
for instance segmentation
57+
58+
Parameters
59+
----------
60+
"""
61+
assert labels.dim() == 1
62+
assert pos.dim() == 2
63+
assert pos.shape[0] == labels.shape[0]
64+
65+
unique_labels = torch.unique(labels)
66+
clusters = {}
67+
ind = torch.arange(0, pos.shape[0])
68+
for l in unique_labels:
69+
if l in ignore_labels:
70+
continue
71+
label_mask = labels == l
72+
local_ind = ind[label_mask]
73+
label_clusters = grow_proximity(
74+
pos[label_mask, :],
75+
batch[label_mask],
76+
nsample=nsample,
77+
radius=radius,
78+
min_cluster_size=min_cluster_size,
79+
)
80+
if len(label_clusters):
81+
remaped_clusters = []
82+
for cluster in label_clusters:
83+
cluster = cluster.to(pos.device)
84+
remaped_clusters.append(local_ind[cluster])
85+
clusters[l.item()] = remaped_clusters
86+
87+
return clusters

torch_points_kernels/torchpoints.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ def furthest_point_sample(xyz, npoint):
3030
(B, npoint) tensor containing the set
3131
"""
3232
if npoint > xyz.shape[1]:
33-
raise ValueError("caanot sample %i points from an input set of %i points" % (npoint, xyz.shape[1]))
33+
raise ValueError(
34+
"caanot sample %i points from an input set of %i points"
35+
% (npoint, xyz.shape[1])
36+
)
3437
if xyz.is_cuda:
3538
return tpcuda.furthest_point_sampling(xyz, npoint)
3639
else:
@@ -99,9 +102,13 @@ def backward(ctx, grad_out):
99102
idx, weight, m = ctx.three_interpolate_for_backward
100103

101104
if grad_out.is_cuda:
102-
grad_features = tpcuda.three_interpolate_grad(grad_out.contiguous(), idx, weight, m)
105+
grad_features = tpcuda.three_interpolate_grad(
106+
grad_out.contiguous(), idx, weight, m
107+
)
103108
else:
104-
grad_features = tpcpu.knn_interpolate_grad(grad_out.contiguous(), idx, weight, m)
109+
grad_features = tpcpu.knn_interpolate_grad(
110+
grad_out.contiguous(), idx, weight, m
111+
)
105112

106113
return grad_features, None, None
107114

@@ -143,17 +150,23 @@ def grouping_operation(features, idx):
143150
all_idx = idx.reshape(idx.shape[0], -1)
144151
all_idx = all_idx.unsqueeze(1).repeat(1, features.shape[1], 1)
145152
grouped_features = features.gather(2, all_idx)
146-
return grouped_features.reshape(idx.shape[0], features.shape[1], idx.shape[1], idx.shape[2])
153+
return grouped_features.reshape(
154+
idx.shape[0], features.shape[1], idx.shape[1], idx.shape[2]
155+
)
147156

148157

149-
def ball_query_dense(radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=None, sort=False):
158+
def ball_query_dense(
159+
radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=None, sort=False
160+
):
150161
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
151162
if new_xyz.is_cuda:
152163
if sort:
153164
raise NotImplementedError("CUDA version does not sort the neighbors")
154165
ind, dist = tpcuda.ball_query_dense(new_xyz, xyz, radius, nsample)
155166
else:
156-
ind, dist = tpcpu.dense_ball_query(new_xyz, xyz, radius, nsample, mode=0, sorted=sort)
167+
ind, dist = tpcpu.dense_ball_query(
168+
new_xyz, xyz, radius, nsample, mode=0, sorted=sort
169+
)
157170
return ind, dist
158171

159172

@@ -162,9 +175,13 @@ def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=False
162175
if x.is_cuda:
163176
if sort:
164177
raise NotImplementedError("CUDA version does not sort the neighbors")
165-
ind, dist = tpcuda.ball_query_partial_dense(x, y, batch_x, batch_y, radius, nsample)
178+
ind, dist = tpcuda.ball_query_partial_dense(
179+
x, y, batch_x, batch_y, radius, nsample
180+
)
166181
else:
167-
ind, dist = tpcpu.batch_ball_query(x, y, batch_x, batch_y, radius, nsample, mode=0, sorted=sort)
182+
ind, dist = tpcpu.batch_ball_query(
183+
x, y, batch_x, batch_y, radius, nsample, mode=0, sorted=sort
184+
)
168185
return ind, dist
169186

170187

@@ -207,7 +224,9 @@ def ball_query(
207224
assert x.size(0) == batch_x.size(0)
208225
assert y.size(0) == batch_y.size(0)
209226
assert x.dim() == 2
210-
return ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=sort)
227+
return ball_query_partial_dense(
228+
radius, nsample, x, y, batch_x, batch_y, sort=sort
229+
)
211230

212231
elif mode.lower() == "dense":
213232
if (batch_x is not None) or (batch_y is not None):

0 commit comments

Comments
 (0)