@@ -37,7 +37,8 @@ def _grow_proximity_core(neighbours, min_cluster_size):
3737
3838
3939def grow_proximity (pos , batch , nsample = 16 , radius = 0.02 , min_cluster_size = 32 ):
40- """ Grow based on proximity only"""
40+ """ Grow based on proximity only
41+ Neighbour search is done on device while the cluster assignement is done on cpu"""
4142 assert pos .shape [0 ] == batch .shape [0 ]
4243 neighbours = (
4344 ball_query_partial_dense (radius , nsample , pos , pos , batch , batch )[0 ]
@@ -57,6 +58,18 @@ def region_grow(
5758
5859 Parameters
5960 ----------
61+ pos: torch.Tensor [N, 3]
62+ Location of the points
63+ labels: torch.Tensor [N,]
64+ labels of each point
65+ ignore_labels:
66+ Labels that should be ignored, no region growing will be performed on those
67+ nsample:
68+ maximum number of neighbours to consider
69+ radius:
70+ radius for the neighbour search
71+ min_cluster_size:
72+ Number of points above which a cluster is considered valid
6073 """
6174 assert labels .dim () == 1
6275 assert pos .dim () == 2
@@ -68,6 +81,8 @@ def region_grow(
6881 for l in unique_labels :
6982 if l in ignore_labels :
7083 continue
84+
85+ # Build clusters for a given label (ignore other points)
7186 label_mask = labels == l
7287 local_ind = ind [label_mask ]
7388 label_clusters = grow_proximity (
@@ -77,6 +92,8 @@ def region_grow(
7792 radius = radius ,
7893 min_cluster_size = min_cluster_size ,
7994 )
95+
96+ # Remap indices to original coordinates
8097 if len (label_clusters ):
8198 remaped_clusters = []
8299 for cluster in label_clusters :
0 commit comments