
Commit 2f83774

Merge pull request TutteInstitute#32 from JelmerBot/dev/branches

add branch detection functionality

2 parents: 5ca7be0 + 1c5ff02

14 files changed: 2749 additions & 106 deletions

doc/detecting_branches.ipynb

Lines changed: 585 additions & 0 deletions
Large diffs are not rendered by default.

doc/for_developers.ipynb

Lines changed: 445 additions & 0 deletions
Large diffs are not rendered by default.

doc/index.rst

Lines changed: 2 additions & 0 deletions
@@ -76,6 +76,8 @@ User Guide
    basic_usage
    benchmarks
    comparable_clusterings
+   detecting_branches
+   for_developers


 ----------

fast_hdbscan/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
 from .hdbscan import HDBSCAN, fast_hdbscan
+from .branches import BranchDetector, find_branch_sub_clusters

 # Force JIT compilation on import
 import numpy as np
@@ -7,4 +8,4 @@
 HDBSCAN(allow_single_cluster=True).fit(random_data)
 HDBSCAN(cluster_selection_method="leaf").fit(random_data)

-__all__ = ["HDBSCAN", "fast_hdbscan"]
+__all__ = ["HDBSCAN", "fast_hdbscan", "BranchDetector", "find_branch_sub_clusters"]

fast_hdbscan/boruvka.py

Lines changed: 34 additions & 24 deletions
@@ -4,26 +4,33 @@
 from .disjoint_set import ds_rank_create, ds_find, ds_union_by_rank
 from .numba_kdtree import parallel_tree_query, rdist, point_to_node_lower_bound_rdist

-@numba.njit(locals={"i": numba.types.int64})
-def merge_components(disjoint_set, candidate_neighbors, candidate_neighbor_distances, point_components):
-    component_edges = {np.int64(0): (np.int64(0), np.int64(1), np.float32(0.0)) for i in range(0)}
+
+@numba.njit(locals={"parent": numba.types.int32})
+def select_components(candidate_distances, candidate_neighbors, point_components):
+    component_edges = {np.int64(0): (np.int32(0), np.int32(1), np.float32(0.0)) for i in range(0)}

     # Find the best edges from each component
-    for i in range(candidate_neighbors.shape[0]):
-        from_component = np.int64(point_components[i])
+    for parent, (distance, neighbor, from_component) in enumerate(
+        zip(candidate_distances, candidate_neighbors, point_components)
+    ):
         if from_component in component_edges:
-            if candidate_neighbor_distances[i] < component_edges[from_component][2]:
-                component_edges[from_component] = (np.int64(i), np.int64(candidate_neighbors[i]), candidate_neighbor_distances[i])
+            if distance < component_edges[from_component][2]:
+                component_edges[from_component] = (parent, neighbor, distance)
         else:
-            component_edges[from_component] = (np.int64(i), np.int64(candidate_neighbors[i]), candidate_neighbor_distances[i])
+            component_edges[from_component] = (parent, neighbor, distance)
+
+    return component_edges
+

+@numba.njit()
+def merge_components(disjoint_set, component_edges):
     result = np.empty((len(component_edges), 3), dtype=np.float64)
     result_idx = 0

     # Add the best edges to the edge set and merge the relevant components
     for edge in component_edges.values():
-        from_component = ds_find(disjoint_set, np.int32(edge[0]))
-        to_component = ds_find(disjoint_set, np.int32(edge[1]))
+        from_component = ds_find(disjoint_set, edge[0])
+        to_component = ds_find(disjoint_set, edge[1])
         if from_component != to_component:
             result[result_idx] = (np.float64(edge[0]), np.float64(edge[1]), np.float64(edge[2]))
             result_idx += 1
@@ -34,10 +41,13 @@ def merge_components(disjoint_set, candidate_neighbors, candidate_neighbor_dista


 @numba.njit(parallel=True)
-def update_component_vectors(tree, disjoint_set, node_components, point_components):
+def update_point_components(disjoint_set, point_components):
     for i in numba.prange(point_components.shape[0]):
         point_components[i] = ds_find(disjoint_set, np.int32(i))

+
+@numba.njit()
+def update_node_components(tree, node_components, point_components):
     for i in range(tree.node_data.shape[0] - 1, -1, -1):
         node_info = tree.node_data[i]

@@ -272,28 +282,28 @@ def parallel_boruvka(tree, min_samples=10, sample_weights=None):
         expected_neighbors = min_samples / mean_sample_weight
         distances, neighbors = parallel_tree_query(tree, tree.data, k=int(2 * expected_neighbors))
         core_distances = sample_weight_core_distance(distances, neighbors, sample_weights, min_samples)
-        edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-        update_component_vectors(tree, components_disjoint_set, node_components, point_components)
     else:
         if min_samples > 1:
             distances, neighbors = parallel_tree_query(tree, tree.data, k=min_samples + 1, output_rdist=True)
             core_distances = distances.T[-1]
-            edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-            update_component_vectors(tree, components_disjoint_set, node_components, point_components)
         else:
             core_distances = np.zeros(tree.data.shape[0], dtype=np.float32)
             distances, neighbors = parallel_tree_query(tree, tree.data, k=2)
-            edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-            update_component_vectors(tree, components_disjoint_set, node_components, point_components)

-    while n_components > 1:
+    edges = [np.empty((0, 3), dtype=np.float64) for _ in range(0)]
+    new_edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
+    while True:
+        edges.append(new_edges)
+        n_components -= new_edges.shape[0]
+        if n_components == 1:
+            break
+        update_point_components(components_disjoint_set, point_components)
+        update_node_components(tree, node_components, point_components)
         candidate_distances, candidate_indices = boruvka_tree_query(tree, node_components, point_components,
                                                                     core_distances)
-        new_edges = merge_components(components_disjoint_set, candidate_indices, candidate_distances, point_components)
-        update_component_vectors(tree, components_disjoint_set, node_components, point_components)
-
-        edges = np.vstack((edges, new_edges))
-        n_components = np.unique(point_components).shape[0]
+        component_edges = select_components(candidate_distances, candidate_indices, point_components)
+        new_edges = merge_components(components_disjoint_set, component_edges)

+    edges = np.vstack(edges)
     edges[:, 2] = np.sqrt(edges.T[2])
-    return edges
+    return edges, neighbors[:, 1:], np.sqrt(core_distances)
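
The refactor splits the old merge_components kernel into a selection pass (pick each component's cheapest outgoing edge) and a merge pass (union the endpoints), and it now tracks the component count by subtracting the number of accepted edges instead of recounting with np.unique each round. A standalone pure-Python sketch of one such select-then-merge Borůvka round, using a dense distance matrix in place of the library's kd-tree queries and numba disjoint set (an illustrative simplification, not the library's API):

    # Toy Boruvka MST with the same select-then-merge structure as above.
    import numpy as np

    def find(parent, x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    def boruvka_mst(dist):
        n = dist.shape[0]
        parent = np.arange(n)
        edges, n_components = [], n
        while n_components > 1:
            # Select: cheapest outgoing edge per component (cf. select_components).
            best = {}
            for i in range(n):
                ci = find(parent, i)
                for j in range(n):
                    if find(parent, j) != ci:
                        if ci not in best or dist[i, j] < best[ci][2]:
                            best[ci] = (i, j, dist[i, j])
            # Merge: union endpoints, skipping pairs already joined this round
            # (cf. merge_components), and decrement the component count by the
            # number of accepted edges instead of recounting components.
            for i, j, d in best.values():
                ri, rj = find(parent, i), find(parent, j)
                if ri != rj:
                    parent[ri] = rj
                    edges.append((i, j, d))
                    n_components -= 1
        return edges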

fast_hdbscan/branches.py

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
+import numpy as np
+from .sub_clusters import SubClusterDetector, find_sub_clusters
+
+
+def compute_centrality(data, probabilities, *args):
+    points = args[-1]
+    cluster_data = data[points, :]
+    centroid = np.average(cluster_data, weights=probabilities[points], axis=0)
+    return 1 / np.linalg.norm(cluster_data - centroid[None, :], axis=1)
+
+
+def apply_branch_threshold(
+    labels,
+    branch_labels,
+    probabilities,
+    cluster_probabilities,
+    cluster_points,
+    linkage_trees,
+    label_sides_as_branches=False,
+):
+    running_id = 0
+    min_branch_count = 1 if label_sides_as_branches else 2
+    for pts, tree in zip(cluster_points, linkage_trees):
+        unique_branch_labels = np.unique(branch_labels[pts])
+        has_noise = int(unique_branch_labels[0] == -1)
+        num_branches = len(unique_branch_labels) - has_noise
+        if num_branches <= min_branch_count and tree is not None:
+            labels[pts] = running_id
+            probabilities[pts] = cluster_probabilities[pts]
+            running_id += 1
+        else:
+            labels[pts] = branch_labels[pts] + has_noise + running_id
+            running_id += num_branches + has_noise
+
+
+def find_branch_sub_clusters(
+    clusterer,
+    cluster_labels=None,
+    cluster_probabilities=None,
+    label_sides_as_branches=False,
+    min_cluster_size=None,
+    max_cluster_size=None,
+    allow_single_cluster=None,
+    cluster_selection_method=None,
+    cluster_selection_epsilon=0.0,
+    cluster_selection_persistence=0.0,
+):
+    result = find_sub_clusters(
+        clusterer,
+        cluster_labels,
+        cluster_probabilities,
+        lens_callback=compute_centrality,
+        min_cluster_size=min_cluster_size,
+        max_cluster_size=max_cluster_size,
+        allow_single_cluster=allow_single_cluster,
+        cluster_selection_method=cluster_selection_method,
+        cluster_selection_epsilon=cluster_selection_epsilon,
+        cluster_selection_persistence=cluster_selection_persistence,
+    )
+    apply_branch_threshold(
+        result[0],
+        result[4],
+        result[1],
+        result[3],
+        result[-1],
+        label_sides_as_branches=label_sides_as_branches,
+    )
+    return result
+
+
+class BranchDetector(SubClusterDetector):
+    """
+    Performs a flare-detection post-processing step to detect branches within
+    clusters [1]_.
+
+    For each cluster, a graph is constructed connecting the data points based on
+    their mutual reachability distances. Each edge is given a centrality value
+    based on how far it lies from the cluster's center. Then, the edges are
+    clustered as if that centrality was a distance, progressively removing the
+    'center' of each cluster and seeing how many branches remain.
+
+    References
+    ----------
+    .. [1] Bot, D. M., Peeters, J., Liesenborgs J., & Aerts, J. (2023, November).
+       FLASC: A Flare-Sensitive Clustering Algorithm: Extending HDBSCAN* for
+       Detecting Branches in Clusters. arXiv:2311.15887.
+    """
+
+    def __init__(
+        self,
+        min_cluster_size=None,
+        max_cluster_size=None,
+        allow_single_cluster=None,
+        cluster_selection_method=None,
+        cluster_selection_epsilon=0.0,
+        cluster_selection_persistence=0.0,
+        propagate_labels=False,
+        label_sides_as_branches=False,
+    ):
+        super().__init__(
+            min_cluster_size=min_cluster_size,
+            max_cluster_size=max_cluster_size,
+            allow_single_cluster=allow_single_cluster,
+            cluster_selection_method=cluster_selection_method,
+            cluster_selection_epsilon=cluster_selection_epsilon,
+            cluster_selection_persistence=cluster_selection_persistence,
+            propagate_labels=propagate_labels,
+        )
+        self.label_sides_as_branches = label_sides_as_branches
+
+    def fit(self, clusterer, labels=None, probabilities=None, sample_weight=None):
+        super().fit(clusterer, labels, probabilities, sample_weight, compute_centrality)
+        apply_branch_threshold(
+            self.labels_,
+            self.sub_cluster_labels_,
+            self.probabilities_,
+            self.cluster_probabilities_,
+            self.cluster_points_,
+            self._linkage_trees,
+            label_sides_as_branches=self.label_sides_as_branches,
+        )
+        self.branch_labels_ = self.sub_cluster_labels_
+        self.branch_probabilities_ = self.sub_cluster_probabilities_
+        self.centralities_ = self.lens_values_
+        return self
+
+    def propagated_labels(self, label_sides_as_branches=None):
+        if label_sides_as_branches is None:
+            label_sides_as_branches = self.label_sides_as_branches
+
+        labels, branch_labels = super().propagated_labels()
+        apply_branch_threshold(
+            labels,
+            branch_labels,
+            np.zeros_like(self.probabilities_),
+            np.zeros_like(self.probabilities_),
+            self.cluster_points_,
+            self._linkage_trees,
+            label_sides_as_branches=label_sides_as_branches,
+        )
+        return labels, branch_labels
+
+    @property
+    def approximation_graph_(self):
+        """See :class:`~hdbscan.plots.ApproximationGraph` for documentation."""
+        return super()._make_approximation_graph(
+            lens_name="centrality", sub_cluster_name="branch"
+        )
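
The class mirrors the familiar scikit-learn fit pattern, but takes a fitted clusterer rather than raw data. A short usage sketch (the dataset and parameter values are illustrative assumptions; the fit signature and fitted attributes come from the diff above):

    # Sketch only: data and parameters are assumptions.
    from sklearn.datasets import make_moons

    from fast_hdbscan import HDBSCAN, BranchDetector

    data, _ = make_moons(n_samples=2000, noise=0.07, random_state=42)
    clusterer = HDBSCAN(min_cluster_size=25).fit(data)

    # Post-process the fitted clusterer, splitting clusters into branches.
    detector = BranchDetector(label_sides_as_branches=True).fit(clusterer)

    # Attributes set by fit(): combined labels, per-point branch labels,
    # branch membership probabilities, and the centrality lens values.
    print(detector.labels_[:10])
    print(detector.branch_labels_[:10])
    print(detector.branch_probabilities_[:10])
    print(detector.centralities_[:10])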
