Skip to content

Commit d40ab87

Browse files
committed
add branch detection
1 parent 658948e commit d40ab87

File tree

3 files changed

+311
-1
lines changed

3 files changed

+311
-1
lines changed

fast_hdbscan/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .hdbscan import HDBSCAN, fast_hdbscan
2+
from .branches import BranchDetector, find_branch_sub_clusters
23

34
# Force JIT compilation on import
45
import numpy as np
@@ -7,4 +8,4 @@
78
HDBSCAN(allow_single_cluster=True).fit(random_data)
89
HDBSCAN(cluster_selection_method="leaf").fit(random_data)
910

10-
__all__ = ["HDBSCAN", "fast_hdbscan"]
11+
__all__ = ["HDBSCAN", "fast_hdbscan", "BranchDetector", "find_branch_sub_clusters"]

fast_hdbscan/branches.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import numpy as np
2+
from .sub_clusters import SubClusterDetector, find_sub_clusters
3+
4+
5+
def compute_centrality(data, probabilities, *args):
6+
points = args[-1]
7+
cluster_data = data[points, :]
8+
centroid = np.average(cluster_data, weights=probabilities[points], axis=0)
9+
return 1 / np.linalg.norm(cluster_data - centroid[None, :], axis=1)
10+
11+
12+
def apply_branch_threshold(
13+
labels,
14+
branch_labels,
15+
probabilities,
16+
cluster_probabilities,
17+
cluster_points,
18+
linkage_trees,
19+
label_sides_as_branches=False,
20+
):
21+
running_id = 0
22+
min_branch_count = 1 if label_sides_as_branches else 2
23+
for pts, tree in zip(cluster_points, linkage_trees):
24+
unique_branch_labels = np.unique(branch_labels[pts])
25+
has_noise = int(unique_branch_labels[0] == -1)
26+
num_branches = len(unique_branch_labels) - has_noise
27+
if num_branches <= min_branch_count and tree is not None:
28+
labels[pts] = running_id
29+
probabilities[pts] = cluster_probabilities[pts]
30+
running_id += 1
31+
continue
32+
else:
33+
branch_labels[pts] = np.where(
34+
branch_labels[pts] < 0, num_branches, branch_labels[pts]
35+
)
36+
labels[pts] = branch_labels[pts] + running_id
37+
running_id += num_branches + has_noise
38+
39+
40+
def find_branch_sub_clusters(
41+
clusterer,
42+
cluster_labels=None,
43+
cluster_probabilities=None,
44+
*,
45+
min_branch_size=None,
46+
max_branch_size=None,
47+
allow_single_branch=None,
48+
branch_selection_method=None,
49+
branch_selection_epsilon=0.0,
50+
branch_selection_persistence=0.0,
51+
label_sides_as_branches=False,
52+
):
53+
result = find_sub_clusters(
54+
clusterer,
55+
cluster_labels,
56+
cluster_probabilities,
57+
lens_callback=compute_centrality,
58+
min_cluster_size=min_branch_size,
59+
max_cluster_size=max_branch_size,
60+
allow_single_cluster=allow_single_branch,
61+
cluster_selection_method=branch_selection_method,
62+
cluster_selection_epsilon=branch_selection_epsilon,
63+
cluster_selection_persistence=branch_selection_persistence,
64+
)
65+
apply_branch_threshold(
66+
result[0],
67+
result[4],
68+
result[1],
69+
result[3],
70+
result[-1],
71+
label_sides_as_branches=label_sides_as_branches,
72+
)
73+
return result
74+
75+
76+
class BranchDetector(SubClusterDetector):
77+
"""
78+
Performs a flare-detection post-processing step to detect branches within
79+
clusters [1]_.
80+
81+
For each cluster, a graph is constructed connecting the data points based on
82+
their mutual reachability distances. Each edge is given a centrality value
83+
based on how far it lies from the cluster's center. Then, the edges are
84+
clustered as if that centrality was a distance, progressively removing the
85+
'center' of each cluster and seeing how many branches remain.
86+
87+
References
88+
----------
89+
.. [1] Bot, D. M., Peeters, J., Liesenborgs J., & Aerts, J. (2023, November).
90+
FLASC: A Flare-Sensitive Clustering Algorithm: Extending HDBSCAN* for
91+
Detecting Branches in Clusters. arXiv:2311.15887.
92+
"""
93+
94+
def __init__(
95+
self,
96+
*,
97+
min_branch_size=None,
98+
max_branch_size=None,
99+
allow_single_branch=None,
100+
branch_selection_method=None,
101+
branch_selection_epsilon=0.0,
102+
branch_selection_persistence=0.0,
103+
label_sides_as_branches=False,
104+
):
105+
super().__init__(
106+
min_cluster_size=min_branch_size,
107+
max_cluster_size=max_branch_size,
108+
allow_single_cluster=allow_single_branch,
109+
cluster_selection_method=branch_selection_method,
110+
cluster_selection_epsilon=branch_selection_epsilon,
111+
cluster_selection_persistence=branch_selection_persistence,
112+
)
113+
self.label_sides_as_branches = label_sides_as_branches
114+
115+
def fit(self, clusterer, labels=None, probabilities=None):
116+
super().fit(clusterer, labels, probabilities, compute_centrality)
117+
apply_branch_threshold(
118+
self.labels_,
119+
self.sub_cluster_labels_,
120+
self.probabilities_,
121+
self.cluster_probabilities_,
122+
self.cluster_points_,
123+
self.linkage_trees_,
124+
label_sides_as_branches=self.label_sides_as_branches,
125+
)
126+
self.branch_labels_ = self.sub_cluster_labels_
127+
self.branch_probabilities_ = self.sub_cluster_probabilities_
128+
self.centralities_ = self.lens_values_
129+
return self
130+
131+
@property
132+
def approximation_graph_(self):
133+
"""See :class:`~hdbscan.plots.ApproximationGraph` for documentation."""
134+
return super()._make_approximation_graph(
135+
lens_name="centrality", sub_cluster_name="branch"
136+
)
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
"""
2+
Test for the branches module.
3+
"""
4+
5+
import pytest
6+
import numpy as np
7+
from sklearn.exceptions import NotFittedError
8+
from fast_hdbscan import HDBSCAN, BranchDetector, find_branch_sub_clusters
9+
10+
11+
def make_branches(points_per_branch=30):
12+
# Control points for line segments that merge three clusters
13+
p0 = (0.13, -0.26)
14+
p1 = (0.24, -0.12)
15+
p2 = (0.32, 0.1)
16+
p3 = (0.13, 0.1)
17+
18+
# Noisy points along lines between three clusters
19+
return np.concatenate(
20+
[
21+
np.column_stack(
22+
(
23+
np.linspace(p_start[0], p_end[0], points_per_branch),
24+
np.linspace(p_start[1], p_end[1], points_per_branch),
25+
)
26+
)
27+
+ np.random.normal(size=(points_per_branch, 2), scale=0.01)
28+
for p_start, p_end in [(p0, p1), (p1, p2), (p1, p3)]
29+
]
30+
)
31+
32+
33+
np.random.seed(0)
34+
X = np.concatenate(
35+
(
36+
make_branches(),
37+
make_branches()[:60] + np.array([0.3, 0]),
38+
)
39+
)
40+
c = HDBSCAN(min_samples=5, min_cluster_size=10).fit(X)
41+
42+
43+
def check_detected_groups(c, n_clusters=3, n_branches=6, overridden=False):
44+
"""Checks branch_detector output for main invariants."""
45+
noise_mask = c.labels_ == -1
46+
assert np.all(np.unique(c.labels_[~noise_mask]) == np.arange(n_branches))
47+
assert np.all(np.unique(c.cluster_labels_[~noise_mask]) == np.arange(n_clusters))
48+
assert (c.branch_labels_[noise_mask] == 0).all()
49+
assert (c.branch_probabilities_[noise_mask] == 1.0).all()
50+
assert (c.probabilities_[noise_mask] == 0.0).all()
51+
assert (c.cluster_probabilities_[noise_mask] == 0.0).all()
52+
if not overridden:
53+
assert len(c.cluster_points_) == n_clusters
54+
for condensed_tree, linkage_tree in zip(c._condensed_trees, c._linkage_trees):
55+
assert linkage_tree is not None
56+
assert condensed_tree is not None
57+
58+
59+
# --- Detecting Branches
60+
61+
62+
def test_attributes():
63+
def check_attributes():
64+
b = BranchDetector().fit(c)
65+
check_detected_groups(b, n_clusters=2, n_branches=5)
66+
assert len(b.linkage_trees_) == 2
67+
assert len(b.condensed_trees_) == 2
68+
assert isinstance(b.condensed_trees_[0], CondensedTree)
69+
assert isinstance(b.linkage_trees_[0], SingleLinkageTree)
70+
assert isinstance(b.approximation_graph_, ApproximationGraph)
71+
72+
try:
73+
from hdbscan.plots import ApproximationGraph, CondensedTree, SingleLinkageTree
74+
75+
check_attributes()
76+
except ImportError:
77+
pass
78+
79+
80+
def test_selection_method():
81+
b = BranchDetector(branch_selection_method="eom").fit(c)
82+
check_detected_groups(b, n_clusters=2, n_branches=5)
83+
84+
b = BranchDetector(branch_selection_method="leaf").fit(c)
85+
check_detected_groups(b, n_clusters=2, n_branches=5)
86+
87+
88+
def test_min_branch_size():
89+
b = BranchDetector(min_branch_size=7).fit(c)
90+
labels, counts = np.unique(b.labels_[b.branch_probabilities_ > 0], return_counts=True)
91+
assert (counts[labels >= 0] >= 7).all()
92+
check_detected_groups(b, n_clusters=2, n_branches=5)
93+
94+
95+
def test_label_sides_as_branches():
96+
b = BranchDetector(label_sides_as_branches=True).fit(c)
97+
check_detected_groups(b, n_clusters=2, n_branches=6)
98+
99+
100+
def test_max_branch_size():
101+
b = BranchDetector(label_sides_as_branches=True, max_branch_size=25).fit(c)
102+
check_detected_groups(b, n_clusters=2, n_branches=4)
103+
104+
105+
def test_override_cluster_labels():
106+
X_missing = X.copy()
107+
X_missing[60:80] = np.nan
108+
c = HDBSCAN(min_cluster_size=5).fit(X_missing)
109+
split_y = c.labels_.copy()
110+
split_y[split_y == 1] = 0
111+
split_y[split_y == 2] = 1
112+
b = BranchDetector(label_sides_as_branches=True).fit(c, split_y)
113+
check_detected_groups(b, n_clusters=2, n_branches=5, overridden=True)
114+
assert b._condensed_trees[0] is None
115+
assert b._linkage_trees[0] is None
116+
117+
118+
def test_allow_single_branch_with_filters():
119+
# Without persistence, find 6 branches
120+
b = BranchDetector(
121+
min_branch_size=5,
122+
branch_selection_method="leaf",
123+
).fit(c)
124+
unique_labels = np.unique(b.labels_)
125+
assert len(unique_labels) == 5
126+
127+
# Adding persistence removes the branches
128+
b = BranchDetector(
129+
min_branch_size=5,
130+
branch_selection_method="leaf",
131+
branch_selection_persistence=0.15,
132+
).fit(c)
133+
unique_labels = np.unique(b.labels_)
134+
assert len(unique_labels) == 2
135+
136+
# Adding epsilon removes some branches
137+
b = BranchDetector(
138+
min_branch_size=5,
139+
branch_selection_method="leaf",
140+
branch_selection_epsilon=1 / 0.002,
141+
).fit(c)
142+
unique_labels = np.unique(b.labels_)
143+
assert len(unique_labels) == 2
144+
145+
146+
def test_badargs():
147+
c_nofit = HDBSCAN(min_cluster_size=5)
148+
149+
with pytest.raises(TypeError):
150+
find_branch_sub_clusters("fail")
151+
with pytest.raises(TypeError):
152+
find_branch_sub_clusters(None)
153+
with pytest.raises(NotFittedError):
154+
find_branch_sub_clusters(c_nofit)
155+
with pytest.raises(ValueError):
156+
find_branch_sub_clusters(c, min_branch_size=-1)
157+
with pytest.raises(ValueError):
158+
find_branch_sub_clusters(c, min_branch_size=0)
159+
with pytest.raises(ValueError):
160+
find_branch_sub_clusters(c, min_branch_size=1)
161+
with pytest.raises(ValueError):
162+
find_branch_sub_clusters(c, min_branch_size=2.0)
163+
with pytest.raises(ValueError):
164+
find_branch_sub_clusters(c, min_branch_size="fail")
165+
with pytest.raises(ValueError):
166+
find_branch_sub_clusters(c, branch_selection_persistence=-0.1)
167+
with pytest.raises(ValueError):
168+
find_branch_sub_clusters(c, branch_selection_epsilon=-0.1)
169+
with pytest.raises(ValueError):
170+
find_branch_sub_clusters(
171+
c,
172+
branch_selection_method="something_else",
173+
)

0 commit comments

Comments
 (0)