Skip to content

Commit eb5f331

Browse files
committed
add entry point from mst edges
1 parent e6e6b5e commit eb5f331

File tree

3 files changed

+85
-40
lines changed

3 files changed

+85
-40
lines changed

fast_hdbscan/cluster_trees.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,10 @@ def unselect_below_node_bcubed(node, cluster_tree, selected_clusters, unselected
409409
unselected_nodes[child] = True
410410

411411
@numba.njit()
412-
def extract_clusters_bcubed(condensed_tree, cluster_tree, label_indices, allow_virtual_nodes=False, allow_single_cluster=False):
412+
def extract_clusters_bcubed(condensed_tree, cluster_tree, data_labels, allow_virtual_nodes=False, allow_single_cluster=False):
413+
label_indices = Dict()
414+
for index in np.flatnonzero(data_labels > -1):
415+
label_indices[index] = data_labels[index]
413416

414417
if allow_virtual_nodes:
415418

fast_hdbscan/hdbscan.py

Lines changed: 54 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
except ImportError:
3131
_HAVE_HDBSCAN = False
3232

33-
from numba.typed import Dict
3433

3534

3635
def to_numpy_rec_array(named_tuple_tree):
@@ -137,7 +136,7 @@ def fast_hdbscan(
137136
data,
138137
data_labels=None,
139138
semi_supervised=False,
140-
ss_algorithm=None,
139+
ss_algorithm='bc',
141140
min_samples=10,
142141
min_cluster_size=10,
143142
cluster_selection_method="eom",
@@ -149,17 +148,14 @@ def fast_hdbscan(
149148
):
150149
data = check_array(data)
151150

152-
if semi_supervised and data_labels is None:
153-
raise ValueError(
154-
"data_labels must not be None when semi_supervised is set to True!"
155-
)
156-
151+
# Detect parameter inconsistencies early.
157152
if semi_supervised:
158-
label_indices = np.flatnonzero(data_labels > -1)
159-
label_values = data_labels[label_indices]
160-
data_labels_dict = Dict()
161-
for index, label in zip(label_indices, label_values):
162-
data_labels_dict[index] = label
153+
if data_labels is None:
154+
raise ValueError(
155+
"data_labels must not be None when semi_supervised is set to True!"
156+
)
157+
if ss_algorithm not in ["bc", "bc_simple"]:
158+
raise ValueError(f"Invalid ss_algorithm {ss_algorithm}")
163159

164160
if (
165161
(not (np.issubdtype(type(min_samples), np.integer) or min_samples is None))
@@ -184,38 +180,61 @@ def fast_hdbscan(
184180
min_samples=min_cluster_size if min_samples is None else min_samples,
185181
sample_weights=sample_weights,
186182
)
183+
184+
return fast_hdbscan_mst_edges(
185+
edges,
186+
data_labels=data_labels,
187+
semi_supervised=semi_supervised,
188+
ss_algorithm=ss_algorithm,
189+
min_cluster_size=min_cluster_size,
190+
cluster_selection_method=cluster_selection_method,
191+
max_cluster_size=max_cluster_size,
192+
allow_single_cluster=allow_single_cluster,
193+
cluster_selection_epsilon=cluster_selection_epsilon,
194+
sample_weights=sample_weights,
195+
)[: (None if return_trees else 2)]
196+
197+
198+
def fast_hdbscan_mst_edges(
199+
edges,
200+
data_labels=None,
201+
semi_supervised=False,
202+
ss_algorithm='bc',
203+
min_cluster_size=10,
204+
cluster_selection_method="eom",
205+
max_cluster_size=np.inf,
206+
allow_single_cluster=False,
207+
cluster_selection_epsilon=0.0,
208+
sample_weights=None,
209+
):
187210
sorted_mst = edges[np.argsort(edges.T[2])]
188211
if sample_weights is None:
189212
linkage_tree = mst_to_linkage_tree(sorted_mst)
190213
else:
191214
linkage_tree = mst_to_linkage_tree_w_sample_weights(sorted_mst, sample_weights)
192-
condensed_tree = condense_tree(linkage_tree, min_cluster_size=min_cluster_size, sample_weights=sample_weights)
215+
condensed_tree = condense_tree(
216+
linkage_tree, min_cluster_size=min_cluster_size, sample_weights=sample_weights
217+
)
193218
if cluster_selection_epsilon > 0.0 or cluster_selection_method == "eom":
194219
cluster_tree = cluster_tree_from_condensed_tree(condensed_tree)
195220

196221
if cluster_selection_method == "eom":
197222
if semi_supervised:
198-
if ss_algorithm == "bc":
199-
selected_clusters = extract_clusters_bcubed(
200-
condensed_tree,
201-
cluster_tree,
202-
data_labels_dict,
203-
allow_virtual_nodes=True,
204-
allow_single_cluster=allow_single_cluster,
205-
)
206-
elif ss_algorithm == "bc_simple":
207-
selected_clusters = extract_clusters_bcubed(
208-
condensed_tree,
209-
cluster_tree,
210-
data_labels_dict,
211-
allow_virtual_nodes=False,
212-
allow_single_cluster=allow_single_cluster,
213-
)
214-
else:
215-
raise ValueError(f"Invalid ss_algorithm {ss_algorithm}")
223+
# Silently ignores max_cluster_size!
224+
# Assumes ss_algorithm is either 'bc' or 'bc_simple'
225+
selected_clusters = extract_clusters_bcubed(
226+
condensed_tree,
227+
cluster_tree,
228+
data_labels,
229+
allow_virtual_nodes=True if ss_algorithm == "bc" else False,
230+
allow_single_cluster=allow_single_cluster,
231+
)
216232
else:
217233
selected_clusters = extract_eom_clusters(
218-
condensed_tree, cluster_tree, max_cluster_size=max_cluster_size, allow_single_cluster=allow_single_cluster
234+
condensed_tree,
235+
cluster_tree,
236+
max_cluster_size=max_cluster_size,
237+
allow_single_cluster=allow_single_cluster,
219238
)
220239
elif cluster_selection_method == "leaf":
221240
selected_clusters = extract_leaves(
@@ -235,15 +254,13 @@ def fast_hdbscan(
235254
condensed_tree,
236255
selected_clusters,
237256
cluster_selection_epsilon,
238-
n_samples=data.shape[0],
257+
n_samples=edges.shape[0] + 1,
239258
)
240259
membership_strengths = get_point_membership_strength_vector(
241260
condensed_tree, selected_clusters, clusters
242261
)
243262

244-
if return_trees:
245-
return clusters, membership_strengths, linkage_tree, condensed_tree, sorted_mst
246-
return clusters, membership_strengths
263+
return clusters, membership_strengths, linkage_tree, condensed_tree, sorted_mst
247264

248265

249266
class HDBSCAN(BaseEstimator, ClusterMixin):
@@ -257,7 +274,7 @@ def __init__(
257274
max_cluster_size=np.inf,
258275
cluster_selection_epsilon=0.0,
259276
semi_supervised=False,
260-
ss_algorithm=None,
277+
ss_algorithm='bc',
261278
**kwargs,
262279
):
263280
self.min_cluster_size = min_cluster_size

fast_hdbscan/tests/test_hdbscan.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
HDBSCAN,
1616
fast_hdbscan,
1717
)
18+
from fast_hdbscan.hdbscan import fast_hdbscan_mst_edges
1819

1920
# from sklearn.cluster.tests.common import generate_clustered_data
2021
from sklearn.datasets import make_blobs
@@ -149,9 +150,13 @@ def test_hdbscan_badargs():
149150
fast_hdbscan(X, cluster_selection_epsilon=-0.1)
150151
with pytest.raises(ValueError):
151152
fast_hdbscan(X, cluster_selection_method="fail")
153+
with pytest.raises(ValueError):
154+
fast_hdbscan(X, semi_supervised=True, ss_algorithm="fail")
155+
with pytest.raises(ValueError):
156+
fast_hdbscan(X, semi_supervised=True, data_labels=None)
152157

153158

154-
def test_fhdbscan_allow_single_cluster_with_epsilon():
159+
def test_hdbscan_allow_single_cluster_with_epsilon():
155160
np.random.seed(0)
156161
no_structure = np.random.rand(150, 2)
157162
# without epsilon we should see 68 noise points and 8 labels
@@ -173,14 +178,34 @@ def test_fhdbscan_allow_single_cluster_with_epsilon():
173178
assert len(unique_labels) == 2
174179
assert counts[unique_labels == -1] == 2
175180

176-
def test_fhdbscan_max_cluster_size():
181+
def test_hdbscan_max_cluster_size():
177182
model = HDBSCAN(max_cluster_size=30).fit(X)
178183
assert len(set(model.labels_)) >= 3
179184
for label in set(model.labels_):
180185
if label != -1:
181186
assert np.sum(model.labels_ == label) <= 30
182187

183188

189+
def test_mst_entry():
190+
# Assumes default keyword arguments match between class and function
191+
model = HDBSCAN(min_cluster_size=5).fit(X)
192+
(
193+
labels,
194+
probabilities,
195+
linkage_tree,
196+
condensed_tree,
197+
sorted_mst
198+
) = fast_hdbscan_mst_edges(model._min_spanning_tree, min_cluster_size=5)
199+
assert np.all(model.labels_ == labels)
200+
assert np.allclose(model.probabilities_, probabilities)
201+
assert np.allclose(model._min_spanning_tree, sorted_mst)
202+
assert np.allclose(model._single_linkage_tree, linkage_tree)
203+
assert np.allclose(model._condensed_tree['parent'], condensed_tree.parent)
204+
assert np.allclose(model._condensed_tree['child'], condensed_tree.child)
205+
assert np.allclose(model._condensed_tree['lambda_val'], condensed_tree.lambda_val)
206+
assert np.allclose(model._condensed_tree['child_size'], condensed_tree.child_size)
207+
208+
184209
# Disable for now -- need to refactor to meet newer standards
185210
@pytest.mark.skip(reason="need to refactor to meet newer standards")
186211
def test_hdbscan_is_sklearn_estimator():

0 commit comments

Comments (0)