Skip to content

Commit 70dff00

Browse files
committed
add propagate sub cluster label option
1 parent d40ab87 commit 70dff00

File tree

3 files changed

+62
-2
lines changed

3 files changed

+62
-2
lines changed

fast_hdbscan/branches.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def find_branch_sub_clusters(
4949
branch_selection_epsilon=0.0,
5050
branch_selection_persistence=0.0,
5151
label_sides_as_branches=False,
52+
propagate_labels=False,
5253
):
5354
result = find_sub_clusters(
5455
clusterer,
@@ -61,6 +62,7 @@ def find_branch_sub_clusters(
6162
cluster_selection_method=branch_selection_method,
6263
cluster_selection_epsilon=branch_selection_epsilon,
6364
cluster_selection_persistence=branch_selection_persistence,
65+
propagate_labels=propagate_labels,
6466
)
6567
apply_branch_threshold(
6668
result[0],
@@ -101,6 +103,7 @@ def __init__(
101103
branch_selection_epsilon=0.0,
102104
branch_selection_persistence=0.0,
103105
label_sides_as_branches=False,
106+
propagate_labels=False,
104107
):
105108
super().__init__(
106109
min_cluster_size=min_branch_size,
@@ -109,6 +112,7 @@ def __init__(
109112
cluster_selection_method=branch_selection_method,
110113
cluster_selection_epsilon=branch_selection_epsilon,
111114
cluster_selection_persistence=branch_selection_persistence,
115+
propagate_labels=propagate_labels,
112116
)
113117
self.label_sides_as_branches = label_sides_as_branches
114118

fast_hdbscan/sub_clusters.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
induced by those distances within the clusters found by HDBSCAN.
44
"""
55

6-
import numba
76
import numpy as np
87
from sklearn.utils.validation import check_is_fitted
98
from sklearn.base import BaseEstimator, ClusterMixin
@@ -123,6 +122,48 @@ def update_labels(
123122
return labels, probabilities, sub_labels, sub_probabilities, lens_values
124123

125124

125+
def propagate_sub_cluster_labels(labels, sub_labels, graph_list, points_list):
    """Assign noise points (-1) to a sub-cluster via density-weighted voting.

    For every cluster that contains both noise and at least one labelled
    sub-cluster, each noise point repeatedly adopts the label with the
    largest summed edge weight among its already-labelled core-graph
    neighbours, until no noise remains or no further progress is possible.

    Parameters
    ----------
    labels : ndarray of int
        Global cluster labels; updated in place for processed clusters.
    sub_labels : ndarray of int
        Sub-cluster labels with -1 marking noise; updated in place.
    graph_list : sequence
        One CSR-like core graph per cluster, exposing ``indptr``,
        ``indices`` and ``weights`` attributes.
    points_list : sequence of ndarray
        Point indices belonging to each cluster, aligned with graph_list.

    Returns
    -------
    labels, sub_labels
        The (mutated) input arrays.
    """
    running_id = 0
    for points, core_graph in zip(points_list, graph_list):
        # Skip clusters that have no noise points or only a single label.
        unique_sub_labels = np.unique(sub_labels[points])
        if unique_sub_labels[0] != -1 or len(unique_sub_labels) == 1:
            continue

        # Build an undirected adjacency map from the (directed) CSR core
        # graph. Weights are inverted so that closer neighbours (smaller
        # distances) vote more strongly.
        # NOTE(review): assumes core-graph weights are strictly positive
        # (nonzero) -- confirm against the graph construction.
        undirected = [{} for _ in range(len(points))]
        for idx, (start, end) in enumerate(
            zip(core_graph.indptr, core_graph.indptr[1:])
        ):
            for i in range(start, end):
                neighbor = core_graph.indices[i]
                inv_weight = 1 / core_graph.weights[i]
                undirected[idx][neighbor] = inv_weight
                undirected[neighbor][idx] = inv_weight

        # Repeat density-weighted majority votes on noise points until all
        # are assigned, or until a full pass assigns nothing. The progress
        # guard prevents an infinite loop when a noise point is disconnected
        # from every labelled point; such points keep their -1 label.
        while True:
            noise_idx = np.nonzero(sub_labels[points] == -1)[0]
            if noise_idx.shape[0] == 0:
                break
            progressed = False
            for idx in noise_idx:
                # Accumulate vote weight per candidate label.
                candidates = {}
                for neighbor_idx, weight in undirected[idx].items():
                    label = sub_labels[points[neighbor_idx]]
                    if label == -1:
                        continue
                    candidates[label] = candidates.get(label, 0.0) + weight

                if len(candidates) == 0:
                    continue
                sub_labels[points[idx]] = max(
                    candidates.items(), key=lambda x: x[1]
                )[0]
                progressed = True
            if not progressed:
                break

        # Shift the cluster's sub-labels into the global label space. Any
        # unreachable noise point keeps its original global label.
        cluster_sub = sub_labels[points]
        assigned = cluster_sub != -1
        labels[points[assigned]] = cluster_sub[assigned] + running_id
        running_id += len(unique_sub_labels) - 1
    return labels, sub_labels
165+
166+
126167
def remap_results(
127168
labels,
128169
probabilities,
@@ -192,6 +233,7 @@ def find_sub_clusters(
192233
cluster_selection_method=None,
193234
cluster_selection_epsilon=0.0,
194235
cluster_selection_persistence=0.0,
236+
propagate_labels=False,
195237
):
196238
check_is_fitted(
197239
clusterer,
@@ -323,6 +365,12 @@ def find_sub_clusters(
323365
data.shape[0],
324366
)
325367

368+
# Propagate labels if requested
369+
if propagate_labels:
370+
labels, sub_labels = propagate_sub_cluster_labels(
371+
labels, sub_labels, core_graphs, points
372+
)
373+
326374
# Reset for infinite data points
327375
if last_outlier > 0:
328376
(
@@ -377,6 +425,7 @@ def __init__(
377425
cluster_selection_method="eom",
378426
cluster_selection_epsilon=0.0,
379427
cluster_selection_persistence=0.0,
428+
propagate_labels=False,
380429
):
381430
self.lens_values = lens_values
382431
self.min_cluster_size = min_cluster_size
@@ -385,6 +434,7 @@ def __init__(
385434
self.cluster_selection_method = cluster_selection_method
386435
self.cluster_selection_epsilon = cluster_selection_epsilon
387436
self.cluster_selection_persistence = cluster_selection_persistence
437+
self.propagate_labels = propagate_labels
388438

389439
def fit(self, clusterer, labels=None, probabilities=None, lens_callback=None):
390440
"""labels and probabilities override the clusterer's values."""
@@ -413,6 +463,7 @@ def fit(self, clusterer, labels=None, probabilities=None, lens_callback=None):
413463
cluster_selection_method=self.cluster_selection_method,
414464
cluster_selection_epsilon=self.cluster_selection_epsilon,
415465
cluster_selection_persistence=self.cluster_selection_persistence,
466+
propagate_labels=self.propagate_labels,
416467
)
417468
# also store the core distances and raw data for the member functions
418469
self._raw_data = clusterer._raw_data
@@ -458,7 +509,6 @@ def _make_approximation_graph(self, lens_name=None, sub_cluster_name=None):
458509
raw_data=self._raw_data,
459510
)
460511

461-
462512
@property
463513
def condensed_trees_(self):
464514
"""See :class:`~hdbscan.plots.CondensedTree` for documentation."""

fast_hdbscan/tests/test_sub_clusters.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@ def test_selection_method():
8484
check_detected_groups(b, n_clusters=2, n_subs=7)
8585

8686

87+
def test_label_propagation():
    # With propagation enabled, no point may remain labelled as noise.
    detector = SubClusterDetector(lens_values=centrality, propagate_labels=True)
    detector.fit(c)
    assert not np.any(detector.sub_cluster_labels_ < 0)
    check_detected_groups(detector, n_clusters=2, n_subs=5)
92+
8793
def test_min_cluster_size():
8894
b = SubClusterDetector(lens_values=centrality, min_cluster_size=7).fit(c)
8995
labels, counts = np.unique(b.labels_[b.sub_cluster_labels_ >= 0], return_counts=True)

0 commit comments

Comments
 (0)