Skip to content

Commit 809c35a

Browse files
committed
Fix prediction data not honoring cluster_selection_epsilon
1 parent e55f957 commit 809c35a

File tree

5 files changed

+44
-23
lines changed

5 files changed

+44
-23
lines changed

hdbscan/_hdbscan_tree.pyx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,9 @@ cpdef tuple get_clusters(np.ndarray tree, dict stability,
705705
706706
stabilities : ndarray (n_clusters,)
707707
The cluster coherence strengths of each cluster.
708+
709+
selected_clusters : ndarray (n_clusters,)
710+
The ids of the selected clusters.
708711
"""
709712
cdef list node_list
710713
cdef np.ndarray cluster_tree
@@ -803,4 +806,4 @@ cpdef tuple get_clusters(np.ndarray tree, dict stability,
803806
probs = get_probabilities(tree, reverse_cluster_map, labels)
804807
stabilities = get_stability_scores(labels, clusters, stability, max_lambda)
805808

806-
return (labels, probs, stabilities)
809+
return (labels, probs, stabilities, np.array(sorted(clusters)))

hdbscan/flat.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ def HDBSCAN_flat(X, n_clusters=None,
184184
new_clusterer.probabilities_,
185185
new_clusterer.cluster_persistence_,
186186
new_clusterer._condensed_tree,
187-
new_clusterer._single_linkage_tree) = output
187+
new_clusterer._single_linkage_tree,
188+
new_clusterer._selected_clusters) = output
188189

189190
# PredictionData attached to HDBSCAN should also change.
190191
# A function re_init is defined in this module to handle this.

hdbscan/hdbscan_.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _tree_to_labels(
6262
"""
6363
condensed_tree = condense_tree(single_linkage_tree, min_cluster_size)
6464
stability_dict = compute_stability(condensed_tree)
65-
labels, probabilities, stabilities = get_clusters(
65+
labels, probabilities, stabilities, selected_clusters = get_clusters(
6666
condensed_tree,
6767
stability_dict,
6868
cluster_selection_method,
@@ -72,7 +72,8 @@ def _tree_to_labels(
7272
max_cluster_size,
7373
)
7474

75-
return (labels, probabilities, stabilities, condensed_tree, single_linkage_tree)
75+
return (labels, probabilities, stabilities, condensed_tree, single_linkage_tree,
76+
selected_clusters)
7677

7778

7879
def _hdbscan_generic(
@@ -1130,6 +1131,7 @@ def __init__(
11301131
self._outlier_scores = None
11311132
self._prediction_data = None
11321133
self._relative_validity = None
1134+
self._selected_clusters = None
11331135

11341136
def fit(self, X, y=None):
11351137
"""Perform HDBSCAN clustering from features or distance matrix.
@@ -1186,6 +1188,7 @@ def fit(self, X, y=None):
11861188
self.cluster_persistence_,
11871189
self._condensed_tree,
11881190
self._single_linkage_tree,
1191+
self._selected_clusters,
11891192
self._min_spanning_tree,
11901193
) = hdbscan(clean_data, **kwargs)
11911194

@@ -1248,6 +1251,7 @@ def generate_prediction_data(self):
12481251
self._prediction_data = PredictionData(
12491252
self._raw_data,
12501253
self.condensed_tree_,
1254+
self._selected_clusters,
12511255
min_samples,
12521256
tree_type=tree_type,
12531257
metric=self.metric,

hdbscan/prediction.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,14 @@ def _recurse_leaf_dfs(self, current_node):
9595
return sum(
9696
[recurse_leaf_dfs(self.cluster_tree, child) for child in children], [])
9797

98-
def __init__(self, data, condensed_tree, min_samples,
98+
def __init__(self, data, condensed_tree, selected_clusters, min_samples,
9999
tree_type='kdtree', metric='euclidean', **kwargs):
100100
self.raw_data = data.astype(np.float64)
101101
self.tree = self._tree_type_map[tree_type](self.raw_data,
102102
metric=metric, **kwargs)
103103
self.core_distances = self.tree.query(data, k=min_samples)[0][:, -1]
104104
self.dist_metric = DistanceMetric.get_metric(metric, **kwargs)
105105

106-
selected_clusters = sorted(condensed_tree._select_clusters())
107106
# raw_condensed_tree = condensed_tree.to_numpy()
108107
raw_condensed_tree = condensed_tree._raw_tree
109108

hdbscan/tests/test_hdbscan.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def test_hdbscan_distance_matrix():
144144
D = distance.squareform(distance.pdist(X))
145145
D /= np.max(D)
146146

147-
labels, p, persist, ctree, ltree, mtree = hdbscan(D, metric="precomputed")
147+
labels, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(D, metric="precomputed")
148148
# number of clusters, ignoring noise if present
149149
n_clusters_1 = len(set(labels)) - int(-1 in labels) # ignore noise
150150
assert n_clusters_1 == n_clusters
@@ -167,7 +167,7 @@ def test_hdbscan_sparse_distance_matrix():
167167
D = sparse.csr_matrix(D)
168168
D.eliminate_zeros()
169169

170-
labels, p, persist, ctree, ltree, mtree = hdbscan(D, metric="precomputed")
170+
labels, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(D, metric="precomputed")
171171
# number of clusters, ignoring noise if present
172172
n_clusters_1 = len(set(labels)) - int(-1 in labels) # ignore noise
173173
assert n_clusters_1 == n_clusters
@@ -178,7 +178,7 @@ def test_hdbscan_sparse_distance_matrix():
178178

179179

180180
def test_hdbscan_feature_vector():
181-
labels, p, persist, ctree, ltree, mtree = hdbscan(X)
181+
labels, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(X)
182182
n_clusters_1 = len(set(labels)) - int(-1 in labels)
183183
assert n_clusters_1 == n_clusters
184184

@@ -191,7 +191,9 @@ def test_hdbscan_feature_vector():
191191

192192

193193
def test_hdbscan_prims_kdtree():
194-
labels, p, persist, ctree, ltree, mtree = hdbscan(X, algorithm="prims_kdtree")
194+
labels, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(
195+
X, algorithm="prims_kdtree"
196+
)
195197
n_clusters_1 = len(set(labels)) - int(-1 in labels)
196198
assert n_clusters_1 == n_clusters
197199

@@ -203,7 +205,9 @@ def test_hdbscan_prims_kdtree():
203205

204206

205207
def test_hdbscan_prims_balltree():
206-
labels, p, persist, ctree, ltree, mtree = hdbscan(X, algorithm="prims_balltree")
208+
labels, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(
209+
X, algorithm="prims_balltree"
210+
)
207211
n_clusters_1 = len(set(labels)) - int(-1 in labels)
208212
assert n_clusters_1 == n_clusters
209213

@@ -215,7 +219,9 @@ def test_hdbscan_prims_balltree():
215219

216220

217221
def test_hdbscan_boruvka_kdtree():
218-
labels, p, persist, ctree, ltree, mtree = hdbscan(X, algorithm="boruvka_kdtree")
222+
labels, p, persist, ctree, ltree, selclstrs, mtree, = hdbscan(
223+
X, algorithm="boruvka_kdtree"
224+
)
219225
n_clusters_1 = len(set(labels)) - int(-1 in labels)
220226
assert n_clusters_1 == n_clusters
221227

@@ -229,7 +235,9 @@ def test_hdbscan_boruvka_kdtree():
229235

230236

231237
def test_hdbscan_boruvka_balltree():
232-
labels, p, persist, ctree, ltree, mtree = hdbscan(X, algorithm="boruvka_balltree")
238+
labels, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(
239+
X, algorithm="boruvka_balltree"
240+
)
233241
n_clusters_1 = len(set(labels)) - int(-1 in labels)
234242
assert n_clusters_1 == n_clusters
235243

@@ -243,7 +251,7 @@ def test_hdbscan_boruvka_balltree():
243251

244252

245253
def test_hdbscan_generic():
246-
labels, p, persist, ctree, ltree, mtree = hdbscan(X, algorithm="generic")
254+
labels, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(X, algorithm="generic")
247255
n_clusters_1 = len(set(labels)) - int(-1 in labels)
248256
assert n_clusters_1 == n_clusters
249257

@@ -261,7 +269,7 @@ def test_hdbscan_high_dimensional():
261269
H, y = make_blobs(n_samples=50, random_state=0, n_features=64)
262270
# H, y = shuffle(X, y, random_state=7)
263271
H = StandardScaler().fit_transform(H)
264-
labels, p, persist, ctree, ltree, mtree = hdbscan(H)
272+
labels, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(H)
265273
n_clusters_1 = len(set(labels)) - int(-1 in labels)
266274
assert n_clusters_1 == n_clusters
267275

@@ -275,7 +283,7 @@ def test_hdbscan_high_dimensional():
275283

276284

277285
def test_hdbscan_best_balltree_metric():
278-
labels, p, persist, ctree, ltree, mtree = hdbscan(
286+
labels, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(
279287
X, metric="seuclidean", V=np.ones(X.shape[1])
280288
)
281289
n_clusters_1 = len(set(labels)) - int(-1 in labels)
@@ -287,7 +295,9 @@ def test_hdbscan_best_balltree_metric():
287295

288296

289297
def test_hdbscan_no_clusters():
290-
labels, p, persist, ctree, ltree, mtree = hdbscan(X, min_cluster_size=len(X) + 1)
298+
labels, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(
299+
X, min_cluster_size=len(X) + 1
300+
)
291301
n_clusters_1 = len(set(labels)) - int(-1 in labels)
292302
assert n_clusters_1 == 0
293303

@@ -298,7 +308,7 @@ def test_hdbscan_no_clusters():
298308

299309
def test_hdbscan_min_cluster_size():
300310
for min_cluster_size in range(2, len(X) + 1, 1):
301-
labels, p, persist, ctree, ltree, mtree = hdbscan(
311+
labels, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(
302312
X, min_cluster_size=min_cluster_size
303313
)
304314
true_labels = [label for label in labels if label != -1]
@@ -315,7 +325,7 @@ def test_hdbscan_callable_metric():
315325
# metric is the function reference, not the string key.
316326
metric = distance.euclidean
317327

318-
labels, p, persist, ctree, ltree, mtree = hdbscan(X, metric=metric)
328+
labels, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(X, metric=metric)
319329
n_clusters_1 = len(set(labels)) - int(-1 in labels)
320330
assert n_clusters_1 == n_clusters
321331

@@ -333,8 +343,10 @@ def test_hdbscan_boruvka_kdtree_matches():
333343

334344
data = generate_noisy_data()
335345

336-
labels_prims, p, persist, ctree, ltree, mtree = hdbscan(data, algorithm="generic")
337-
labels_boruvka, p, persist, ctree, ltree, mtree = hdbscan(
346+
labels_prims, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(
347+
data, algorithm="generic"
348+
)
349+
labels_boruvka, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(
338350
data, algorithm="boruvka_kdtree"
339351
)
340352

@@ -354,8 +366,10 @@ def test_hdbscan_boruvka_balltree_matches():
354366

355367
data = generate_noisy_data()
356368

357-
labels_prims, p, persist, ctree, ltree, mtree = hdbscan(data, algorithm="generic")
358-
labels_boruvka, p, persist, ctree, ltree, mtree = hdbscan(
369+
labels_prims, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(
370+
data, algorithm="generic"
371+
)
372+
labels_boruvka, p, persist, ctree, ltree, selclstrs, mtree = hdbscan(
359373
data, algorithm="boruvka_balltree"
360374
)
361375

0 commit comments

Comments
 (0)