Skip to content

Commit 916b1cd

Browse files
committed
Further hdbscan coverage improvements
1 parent 81cc4ce commit 916b1cd

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

hdbscan/hdbscan_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def _hdbscan_generic(X, min_samples=5, alpha=1.0,
     if gen_min_span_tree:
         result_min_span_tree = min_spanning_tree.copy()
         for index, row in enumerate(result_min_span_tree[1:], 1):
-            candidates = np.where(np.isclose(mutual_reachability_[row[1]], row[2]))[0]
+            candidates = np.where(np.isclose(mutual_reachability_[int(row[1])], row[2]))[0]
             candidates = np.intersect1d(candidates, min_spanning_tree[:index, :2].astype(int))
             candidates = candidates[candidates != row[1]]
             assert (len(candidates) > 0)

hdbscan/tests/test_hdbscan.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def test_hdbscan_prims_balltree():
                   metric='cosine')
 
 def test_hdbscan_boruvka_kdtree():
-    labels, p, persist, ctree, ltree, mtree = hdbscan(X, algorithm='boruvka_kdtree')
+    labels, p, persist, ctree, ltree, mtree = hdbscan(X, algorithm='boruvka_kdtree', leaf_size=5)
     n_clusters_1 = len(set(labels)) - int(-1 in labels)
     assert_equal(n_clusters_1, n_clusters)
 
@@ -172,7 +172,7 @@ def test_hdbscan_boruvka_kdtree():
                   metric='russelrao')
 
 def test_hdbscan_boruvka_balltree():
-    labels, p, persist, ctree, ltree, mtree = hdbscan(X, algorithm='boruvka_balltree')
+    labels, p, persist, ctree, ltree, mtree = hdbscan(X, algorithm='boruvka_balltree', leaf_size=5)
     n_clusters_1 = len(set(labels)) - int(-1 in labels)
     assert_equal(n_clusters_1, n_clusters)
 
@@ -186,6 +186,15 @@ def test_hdbscan_boruvka_balltree():
                   algorithm='boruvka_balltree',
                   metric='cosine')
 
+def test_hdbscan_generic():
+    labels, p, persist, ctree, ltree, mtree = hdbscan(X, algorithm='generic')
+    n_clusters_1 = len(set(labels)) - int(-1 in labels)
+    assert_equal(n_clusters_1, n_clusters)
+
+    labels = HDBSCAN(algorithm='generic', gen_min_span_tree=True).fit(X).labels_
+    n_clusters_2 = len(set(labels)) - int(-1 in labels)
+    assert_equal(n_clusters_2, n_clusters)
+
 
 def test_hdbscan_high_dimensional():
     H, y = make_blobs(n_samples=50, random_state=0, n_features=64)

0 commit comments

Comments
 (0)