Skip to content

Commit 821ff98

Browse files
committed
add persistence threshold to hdbscan
1 parent 2e7e712 commit 821ff98

File tree

4 files changed

+204
-73
lines changed

4 files changed

+204
-73
lines changed

fast_hdbscan/boruvka.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,4 +296,4 @@ def parallel_boruvka(tree, min_samples=10, sample_weights=None):
296296
n_components = np.unique(point_components).shape[0]
297297

298298
edges[:, 2] = np.sqrt(edges.T[2])
299-
return edges
299+
return edges, neighbors[:, 1:], np.sqrt(core_distances)

fast_hdbscan/cluster_trees.py

Lines changed: 123 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def condense_tree(hierarchy, min_cluster_size=10, max_cluster_size=np.inf, sampl
212212
relabel[right] = parent_node
213213
idx = eliminate_branch(left, parent_node, lambda_value, parents, children, lambdas, sizes, idx, ignore,
214214
hierarchy, num_points)
215-
# Then we have a large left cluster and a small right cluster: relabel the left node; elimiate the right branch
215+
# Then we have a large left cluster and a small right cluster: relabel the left node; eliminate the right branch
216216
elif left_count >= min_cluster_size and right_count < min_cluster_size:
217217
relabel[left] = parent_node
218218
idx = eliminate_branch(right, parent_node, lambda_value, parents, children, lambdas, sizes, idx, ignore,
@@ -250,18 +250,11 @@ def condense_tree(hierarchy, min_cluster_size=10, max_cluster_size=np.inf, sampl
250250

251251

252252
@numba.njit()
253-
def extract_leaves(condensed_tree, allow_single_cluster=True):
254-
n_nodes = condensed_tree.parent.max() + 1
255-
n_points = condensed_tree.parent.min()
256-
leaf_indicator = np.ones(n_nodes, dtype=np.bool_)
257-
leaf_indicator[:n_points] = False
258-
259-
for parent, child_size in zip(condensed_tree.parent, condensed_tree.child_size):
260-
if child_size > 1:
261-
leaf_indicator[parent] = False
262-
263-
return np.nonzero(leaf_indicator)[0]
264-
253+
def extract_leaves(cluster_tree, n_points):
254+
n_nodes = cluster_tree.child.max() + 1
255+
leaf_indicator = np.ones(n_nodes - n_points, dtype=np.bool_)
256+
leaf_indicator[cluster_tree.parent - n_points] = False
257+
return np.nonzero(leaf_indicator)[0] + n_points
265258

266259

267260
# The *_bcubed functions below implement the (semi-supervised) HDBSCAN*(BC) algorithm presented
@@ -448,7 +441,6 @@ def extract_clusters_bcubed(condensed_tree, cluster_tree, data_labels, allow_vir
448441
return np.asarray([node for node, selected in selected_clusters.items() if (selected and (node not in virtual_nodes))])
449442

450443

451-
452444
@numba.njit()
453445
def score_condensed_tree_nodes(condensed_tree):
454446
result = {0: np.float32(0.0) for i in range(0)}
@@ -472,9 +464,17 @@ def score_condensed_tree_nodes(condensed_tree):
472464

473465
@numba.njit()
474466
def cluster_tree_from_condensed_tree(condensed_tree):
475-
mask = condensed_tree.child_size > 1
476-
return CondensedTree(condensed_tree.parent[mask], condensed_tree.child[mask], condensed_tree.lambda_val[mask],
477-
condensed_tree.child_size[mask])
467+
return mask_condensed_tree(condensed_tree, condensed_tree.child_size > 1)
468+
469+
470+
@numba.njit()
471+
def mask_condensed_tree(condensed_tree, mask):
472+
return CondensedTree(
473+
condensed_tree.parent[mask],
474+
condensed_tree.child[mask],
475+
condensed_tree.lambda_val[mask],
476+
condensed_tree.child_size[mask]
477+
)
478478

479479

480480
@numba.njit()
@@ -529,61 +529,136 @@ def extract_eom_clusters(condensed_tree, cluster_tree, max_cluster_size=np.inf,
529529

530530

531531
@numba.njit()
532-
def cluster_epsilon_search(clusters, cluster_tree, min_persistence=0.0):
532+
def simplify_hierarchy(condensed_tree, n_points, persistence_threshold):
533+
keep_mask = np.ones(condensed_tree.parent.shape[0], dtype=np.bool_)
534+
cluster_tree = cluster_tree_from_condensed_tree(condensed_tree)
535+
536+
processed = {np.int64(0)}
537+
processed.clear()
538+
while cluster_tree.parent.shape[0] > 0:
539+
leaves = set(extract_leaves(cluster_tree, n_points))
540+
births = max_lambdas(condensed_tree, leaves)
541+
deaths = min_lambdas(cluster_tree, leaves)
542+
543+
cluster_mask = np.ones(cluster_tree.parent.shape[0], dtype=np.bool_)
544+
for leaf in sorted(leaves, reverse=True):
545+
if leaf in processed or (births[leaf] - deaths[leaf]) >= persistence_threshold:
546+
continue
547+
548+
# Find rows for leaf and sibling
549+
leaf_idx = np.searchsorted(cluster_tree.child, leaf)
550+
parent = cluster_tree.parent[leaf_idx]
551+
if leaf_idx > 0 and cluster_tree.parent[leaf_idx - 1] == parent:
552+
sibling_idx = leaf_idx - 1
553+
else:
554+
sibling_idx = leaf_idx + 1
555+
sibling = cluster_tree.child[sibling_idx]
556+
557+
# Update parent values to the new parent
558+
for idx, row in enumerate(cluster_tree.parent):
559+
if row in [leaf, sibling]:
560+
cluster_tree.parent[idx] = parent
561+
for idx, row in enumerate(condensed_tree.parent):
562+
if row in [leaf, sibling]:
563+
condensed_tree.parent[idx] = parent
564+
condensed_tree.lambda_val[idx] = deaths[leaf]
565+
566+
# Mark visited rows
567+
processed.add(leaf)
568+
processed.add(sibling)
569+
cluster_mask[leaf_idx] = False
570+
cluster_mask[sibling_idx] = False
571+
for idx, child in enumerate(condensed_tree.child):
572+
if child in [leaf, sibling]:
573+
keep_mask[idx] = False
574+
575+
if np.all(cluster_mask):
576+
break
577+
cluster_tree = mask_condensed_tree(cluster_tree, cluster_mask)
578+
579+
condensed_tree = mask_condensed_tree(condensed_tree, keep_mask)
580+
return remap_cluster_ids(condensed_tree, n_points)
581+
582+
583+
@numba.njit()
584+
def remap_cluster_ids(condensed_tree, n_points):
585+
n_nodes = condensed_tree.parent.max() + 1
586+
remaining_parents = np.unique(condensed_tree.parent)
587+
id_map = np.empty(n_nodes - n_points, dtype=np.int64)
588+
id_map[remaining_parents - n_points] = np.arange(
589+
n_points, n_points + remaining_parents.shape[0]
590+
)
591+
for column in [condensed_tree.parent, condensed_tree.child]:
592+
for idx, node in enumerate(column):
593+
if node >= n_points:
594+
column[idx] = id_map[node - n_points]
595+
return condensed_tree
596+
597+
598+
@numba.njit()
599+
def cluster_epsilon_search(clusters, cluster_tree, min_epsilon=0.0):
533600
selected = list()
534601
# only way to create a typed empty set
535602
processed = {np.int64(0)}
536603
processed.clear()
537604

605+
# cluster_tree is sorted with increasing children
606+
# prepare to use binary search on parent in segments_in_branch
607+
parent_order = np.argsort(cluster_tree.parent)
608+
parents = cluster_tree.parent[parent_order]
609+
children = cluster_tree.child[parent_order]
610+
538611
root = cluster_tree.parent.min()
539612
for cluster in clusters:
540-
eps = 1 / cluster_tree.lambda_val[cluster_tree.child == cluster][0]
541-
if eps < min_persistence:
613+
idx = np.searchsorted(cluster_tree.child, cluster)
614+
death_eps = 1 / cluster_tree.lambda_val[idx]
615+
if death_eps < min_epsilon:
542616
if cluster not in processed:
543-
parent = traverse_upwards(cluster_tree, min_persistence, root, cluster)
617+
parent = traverse_upwards(cluster_tree, min_epsilon, root, cluster)
544618
selected.append(parent)
545-
processed |= segments_in_branch(cluster_tree, parent)
619+
processed |= segments_in_branch(parents, children, parent)
546620
else:
547621
selected.append(cluster)
548622
return np.asarray(selected)
549623

550624

551625
@numba.njit()
552-
def traverse_upwards(cluster_tree, min_persistence, root, segment):
626+
def traverse_upwards(cluster_tree, min_epsilon, root, segment):
553627
parent = cluster_tree.parent[cluster_tree.child == segment][0]
554628
if parent == root:
555629
return root
556-
parent_eps = 1 / cluster_tree.lambda_val[cluster_tree.child == parent][0]
557-
if parent_eps >= min_persistence:
630+
parent_death_eps = 1 / cluster_tree.lambda_val[cluster_tree.child == parent][0]
631+
if parent_death_eps >= min_epsilon:
558632
return parent
559633
else:
560-
return traverse_upwards(cluster_tree, min_persistence, root, parent)
634+
return traverse_upwards(cluster_tree, min_epsilon, root, parent)
561635

562636

563637
@numba.njit()
564-
def segments_in_branch(cluster_tree, segment):
638+
def segments_in_branch(parents, children, segment):
565639
# only way to create a typed empty set
566-
result = {np.intp(0)}
640+
child_set = {np.int64(0)}
641+
result = {np.int64(0)}
567642
result.clear()
568643
to_process = {segment}
569644

570645
while len(to_process) > 0:
571646
result |= to_process
572-
to_process = set(cluster_tree.child[
573-
in_set_parallel(cluster_tree.parent, to_process)
574-
])
647+
648+
child_set.clear()
649+
for segment in to_process:
650+
idx = np.searchsorted(parents, segment)
651+
if idx >= len(parents):
652+
continue
653+
child_set.add(children[idx])
654+
child_set.add(children[idx + 1])
655+
656+
to_process.clear()
657+
to_process |= child_set
575658

576659
return result
577660

578661

579-
@numba.njit(parallel=True)
580-
def in_set_parallel(values, targets):
581-
mask = np.empty(values.shape[0], dtype=numba.boolean)
582-
for i in numba.prange(values.shape[0]):
583-
mask[i] = values[i] in targets
584-
return mask
585-
586-
587662
@numba.njit(parallel=True)
588663
def get_cluster_labelling_at_cut(linkage_tree, cut, min_cluster_size):
589664

@@ -628,7 +703,7 @@ def get_cluster_label_vector(
628703
cluster_selection_epsilon,
629704
n_samples,
630705
):
631-
if len(tree.parent) == 0:
706+
if len(tree.parent) == 0 or len(clusters) == 0:
632707
return np.full(n_samples, -1, dtype=np.intp)
633708
root_cluster = tree.parent.min()
634709
result = np.full(n_samples, -1, dtype=np.intp)
@@ -680,6 +755,14 @@ def max_lambdas(tree, clusters):
680755
return result
681756

682757

758+
@numba.njit()
759+
def min_lambdas(cluster_tree, clusters):
760+
return {
761+
c: cluster_tree.lambda_val[np.searchsorted(cluster_tree.child, c)]
762+
for c in clusters
763+
}
764+
765+
683766
@numba.njit()
684767
def get_point_membership_strength_vector(tree, clusters, labels):
685768
result = np.zeros(labels.shape[0], dtype=np.float32)

0 commit comments

Comments (0)