Skip to content

Commit d64e307

Browse files
Add more comments
1 parent 71982c8 commit d64e307

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

torch_points_kernels/cluster.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def _grow_proximity_core(neighbours, min_cluster_size):
3737

3838

3939
def grow_proximity(pos, batch, nsample=16, radius=0.02, min_cluster_size=32):
40-
""" Grow based on proximity only"""
40+
""" Grow based on proximity only
41+
Neighbour search is done on device while the cluster assignement is done on cpu"""
4142
assert pos.shape[0] == batch.shape[0]
4243
neighbours = (
4344
ball_query_partial_dense(radius, nsample, pos, pos, batch, batch)[0]
@@ -57,6 +58,18 @@ def region_grow(
5758
5859
Parameters
5960
----------
61+
pos: torch.Tensor [N, 3]
62+
Location of the points
63+
labels: torch.Tensor [N,]
64+
labels of each point
65+
ignore_labels:
66+
Labels that should be ignored, no region growing will be performed on those
67+
nsample:
68+
maximum number of neighbours to consider
69+
radius:
70+
radius for the neighbour search
71+
min_cluster_size:
72+
Number of points above which a cluster is considered valid
6073
"""
6174
assert labels.dim() == 1
6275
assert pos.dim() == 2
@@ -68,6 +81,8 @@ def region_grow(
6881
for l in unique_labels:
6982
if l in ignore_labels:
7083
continue
84+
85+
# Build clusters for a given label (ignore other points)
7186
label_mask = labels == l
7287
local_ind = ind[label_mask]
7388
label_clusters = grow_proximity(
@@ -77,6 +92,8 @@ def region_grow(
7792
radius=radius,
7893
min_cluster_size=min_cluster_size,
7994
)
95+
96+
# Remap indices to original coordinates
8097
if len(label_clusters):
8198
remaped_clusters = []
8299
for cluster in label_clusters:

torch_points_kernels/torchpoints.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import torch_points_kernels.points_cpu as tpcpu
88
from .knn import knn
9+
from .cluster import region_grow
910

1011
if torch.cuda.is_available():
1112
import torch_points_kernels.points_cuda as tpcuda

0 commit comments

Comments
 (0)