Skip to content

Commit 5d48141

Browse files
committed
Fix type merging issues with linkage_merge_data
1 parent 147b489 commit 5d48141

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

fast_hdbscan/cluster_trees.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,41 @@ def linkage_merge_join(linkage_merge, left, right):
5353

5454

5555
@numba.njit()
56-
def mst_to_linkage_tree(sorted_mst, sample_weights=None):
56+
def mst_to_linkage_tree(sorted_mst):
5757
result = np.empty((sorted_mst.shape[0], sorted_mst.shape[1] + 1))
5858

5959
n_samples = sorted_mst.shape[0] + 1
60-
if sample_weights is None:
61-
linkage_merge = create_linkage_merge_data(n_samples)
62-
else:
63-
linkage_merge = create_linkage_merge_data_w_sample_weights(sample_weights)
60+
linkage_merge = create_linkage_merge_data(n_samples)
61+
62+
for index in range(sorted_mst.shape[0]):
63+
64+
left = np.intp(sorted_mst[index, 0])
65+
right = np.intp(sorted_mst[index, 1])
66+
delta = sorted_mst[index, 2]
67+
68+
left_component = linkage_merge_find(linkage_merge, left)
69+
right_component = linkage_merge_find(linkage_merge, right)
70+
71+
if left_component > right_component:
72+
result[index][0] = left_component
73+
result[index][1] = right_component
74+
else:
75+
result[index][1] = left_component
76+
result[index][0] = right_component
77+
78+
result[index][2] = delta
79+
result[index][3] = linkage_merge.size[left_component] + linkage_merge.size[right_component]
80+
81+
linkage_merge_join(linkage_merge, left_component, right_component)
82+
83+
return result
84+
85+
86+
@numba.njit()
87+
def mst_to_linkage_tree_w_sample_weights(sorted_mst, sample_weights):
88+
result = np.empty((sorted_mst.shape[0], sorted_mst.shape[1] + 1))
89+
90+
linkage_merge = create_linkage_merge_data_w_sample_weights(sample_weights)
6491

6592
for index in range(sorted_mst.shape[0]):
6693

fast_hdbscan/hdbscan.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .boruvka import parallel_boruvka
1212
from .cluster_trees import (
1313
mst_to_linkage_tree,
14+
mst_to_linkage_tree_w_sample_weights,
1415
condense_tree,
1516
extract_eom_clusters,
1617
extract_leaves,
@@ -161,7 +162,10 @@ def fast_hdbscan(
161162
numba_tree, min_samples=min_cluster_size if min_samples is None else min_samples, sample_weights=sample_weights
162163
)
163164
sorted_mst = edges[np.argsort(edges.T[2])]
164-
linkage_tree = mst_to_linkage_tree(sorted_mst, sample_weights=sample_weights)
165+
if sample_weights is None:
166+
linkage_tree = mst_to_linkage_tree(sorted_mst)
167+
else:
168+
linkage_tree = mst_to_linkage_tree_w_sample_weights(sorted_mst, sample_weights)
165169
condensed_tree = condense_tree(linkage_tree, min_cluster_size=min_cluster_size)
166170
if cluster_selection_epsilon > 0.0 or cluster_selection_method == "eom":
167171
cluster_tree = cluster_tree_from_condensed_tree(condensed_tree)

0 commit comments

Comments
 (0)