@@ -212,7 +212,7 @@ def condense_tree(hierarchy, min_cluster_size=10, max_cluster_size=np.inf, sampl
212212 relabel [right ] = parent_node
213213 idx = eliminate_branch (left , parent_node , lambda_value , parents , children , lambdas , sizes , idx , ignore ,
214214 hierarchy , num_points )
215- # Then we have a large left cluster and a small right cluster: relabel the left node; elimiate the right branch
215+ # Then we have a large left cluster and a small right cluster: relabel the left node; eliminate the right branch
216216 elif left_count >= min_cluster_size and right_count < min_cluster_size :
217217 relabel [left ] = parent_node
218218 idx = eliminate_branch (right , parent_node , lambda_value , parents , children , lambdas , sizes , idx , ignore ,
@@ -250,18 +250,11 @@ def condense_tree(hierarchy, min_cluster_size=10, max_cluster_size=np.inf, sampl
250250
251251
252252@numba .njit ()
def extract_leaves(cluster_tree, n_points):
    """Return the ids of the clusters that never appear as a parent.

    Parameters
    ----------
    cluster_tree : CondensedTree
        Only the ``parent`` and ``child`` columns are read. Every id in
        both columns is assumed to be at least ``n_points``, because ids
        are shifted down by ``n_points`` to index the indicator array.
    n_points : int
        Offset subtracted from cluster ids to obtain zero-based slots.

    Returns
    -------
    ndarray of int64
        Ids in ``[n_points, child.max()]`` that do not occur in ``parent``,
        in increasing order. Empty when the tree has no rows.
    """
    # max() is undefined on an empty column; a tree without rows has no leaves.
    if cluster_tree.child.shape[0] == 0:
        return np.empty(0, dtype=np.int64)

    n_nodes = cluster_tree.child.max() + 1
    # One slot per candidate id; anything used as a parent is not a leaf.
    is_leaf = np.ones(n_nodes - n_points, dtype=np.bool_)
    is_leaf[cluster_tree.parent - n_points] = False
    return np.nonzero(is_leaf)[0] + n_points
265258
266259
267260# The *_bcubed functions below implement the (semi-supervised) HDBSCAN*(BC) algorithm presented
@@ -448,7 +441,6 @@ def extract_clusters_bcubed(condensed_tree, cluster_tree, data_labels, allow_vir
448441 return np .asarray ([node for node , selected in selected_clusters .items () if (selected and (node not in virtual_nodes ))])
449442
450443
451-
452444@numba .njit ()
453445def score_condensed_tree_nodes (condensed_tree ):
454446 result = {0 : np .float32 (0.0 ) for i in range (0 )}
@@ -472,9 +464,17 @@ def score_condensed_tree_nodes(condensed_tree):
472464
473465@numba .njit ()
def cluster_tree_from_condensed_tree(condensed_tree):
    """Return the rows of ``condensed_tree`` whose ``child_size`` exceeds 1."""
    is_cluster_row = condensed_tree.child_size > 1
    return mask_condensed_tree(condensed_tree, is_cluster_row)
468+
469+
470+ @numba .njit ()
def mask_condensed_tree(condensed_tree, mask):
    """Build a new CondensedTree holding only the rows selected by ``mask``.

    ``mask`` is applied as an index to each of the four columns
    (``parent``, ``child``, ``lambda_val``, ``child_size``) in turn.
    """
    parent = condensed_tree.parent[mask]
    child = condensed_tree.child[mask]
    lambda_val = condensed_tree.lambda_val[mask]
    child_size = condensed_tree.child_size[mask]
    return CondensedTree(parent, child, lambda_val, child_size)
478478
479479
480480@numba .njit ()
@@ -529,61 +529,136 @@ def extract_eom_clusters(condensed_tree, cluster_tree, max_cluster_size=np.inf,
529529
530530
531531@numba .njit ()
532- def cluster_epsilon_search (clusters , cluster_tree , min_persistence = 0.0 ):
def simplify_hierarchy(condensed_tree, n_points, persistence_threshold):
    """Collapse low-persistence leaf clusters into their parent cluster.

    In each pass the current leaves of the cluster tree are examined. A leaf
    whose ``births[leaf] - deaths[leaf]`` is below ``persistence_threshold``
    is merged, together with its sibling, into their common parent: rows
    that had the leaf or sibling as parent are re-parented, and the rows
    that introduced the leaf and sibling as children are dropped. Passes
    repeat until one makes no merge or the cluster tree is empty.

    Parameters
    ----------
    condensed_tree : CondensedTree
        Input tree. It is not modified; the columns rewritten here are
        copied first.
    n_points : int
        Offset of the first cluster id, forwarded to ``extract_leaves`` and
        ``remap_cluster_ids``.
    persistence_threshold : float
        Leaves with persistence at or above this value are kept.

    Returns
    -------
    CondensedTree
        The simplified tree after ``remap_cluster_ids`` has renumbered it.
    """
    # Work on private copies of the two columns rewritten below, so the
    # caller's arrays are not left half-updated.
    working_tree = CondensedTree(
        condensed_tree.parent.copy(),
        condensed_tree.child,
        condensed_tree.lambda_val.copy(),
        condensed_tree.child_size,
    )
    keep_mask = np.ones(working_tree.parent.shape[0], dtype=np.bool_)
    cluster_tree = cluster_tree_from_condensed_tree(working_tree)

    # only way to create a typed empty set
    processed = {np.int64(0)}
    processed.clear()
    while cluster_tree.parent.shape[0] > 0:
        leaves = set(extract_leaves(cluster_tree, n_points))
        births = max_lambdas(working_tree, leaves)
        deaths = min_lambdas(cluster_tree, leaves)

        cluster_mask = np.ones(cluster_tree.parent.shape[0], dtype=np.bool_)
        for leaf in sorted(leaves, reverse=True):
            if leaf in processed or (births[leaf] - deaths[leaf]) >= persistence_threshold:
                continue

            # Find rows for leaf and sibling; the sibling is taken to be the
            # adjacent row that shares the leaf's parent.
            leaf_idx = np.searchsorted(cluster_tree.child, leaf)
            parent = cluster_tree.parent[leaf_idx]
            if leaf_idx > 0 and cluster_tree.parent[leaf_idx - 1] == parent:
                sibling_idx = leaf_idx - 1
            else:
                sibling_idx = leaf_idx + 1
            sibling = cluster_tree.child[sibling_idx]

            # Update parent values to the new parent
            for idx, row in enumerate(cluster_tree.parent):
                if row == leaf or row == sibling:
                    cluster_tree.parent[idx] = parent
            for idx, row in enumerate(working_tree.parent):
                if row == leaf or row == sibling:
                    working_tree.parent[idx] = parent
                    working_tree.lambda_val[idx] = deaths[leaf]

            # Mark visited rows
            processed.add(leaf)
            processed.add(sibling)
            cluster_mask[leaf_idx] = False
            cluster_mask[sibling_idx] = False
            for idx, child in enumerate(working_tree.child):
                if child == leaf or child == sibling:
                    keep_mask[idx] = False

        # Nothing was merged in this pass, so the hierarchy is stable.
        if np.all(cluster_mask):
            break
        cluster_tree = mask_condensed_tree(cluster_tree, cluster_mask)

    simplified = mask_condensed_tree(working_tree, keep_mask)
    return remap_cluster_ids(simplified, n_points)
581+
582+
583+ @numba .njit ()
def remap_cluster_ids(condensed_tree, n_points):
    """Renumber cluster ids so the ids used as parents become consecutive.

    The distinct values of ``parent`` are mapped, in increasing order, onto
    ``n_points, n_points + 1, ...``. Every entry of ``parent`` and ``child``
    that is at least ``n_points`` is then replaced through that mapping.
    Entries below ``n_points`` are left alone.

    The two columns are rewritten in place and the same tree is returned.
    """
    parents = condensed_tree.parent
    children = condensed_tree.child

    n_nodes = parents.max() + 1
    surviving = np.unique(parents)
    # Slots of ids that no longer occur as a parent stay uninitialised.
    lookup = np.empty(n_nodes - n_points, dtype=np.int64)
    lookup[surviving - n_points] = np.arange(
        n_points, n_points + surviving.shape[0]
    )

    for row in range(parents.shape[0]):
        old_id = parents[row]
        if old_id >= n_points:
            parents[row] = lookup[old_id - n_points]
    for row in range(children.shape[0]):
        old_id = children[row]
        if old_id >= n_points:
            children[row] = lookup[old_id - n_points]
    return condensed_tree
596+
597+
598+ @numba .njit ()
def cluster_epsilon_search(clusters, cluster_tree, min_epsilon=0.0):
    """Replace clusters that die below ``min_epsilon`` by an ancestor.

    For each cluster, ``1 / lambda_val`` of its row in ``cluster_tree`` is
    compared with ``min_epsilon``. Clusters that are not below the threshold
    are kept as they are. Otherwise the ancestor returned by
    ``traverse_upwards`` is selected, and every segment in that ancestor's
    branch is marked as processed so it is not selected again.

    Returns the selected ids as an array, in the order they were chosen.
    """
    chosen = list()
    # only way to create a typed empty set
    processed = {np.int64(0)}
    processed.clear()

    # cluster_tree is sorted with increasing children; build a copy of the
    # (parent, child) columns ordered by parent so that segments_in_branch
    # can binary-search on parent.
    by_parent = np.argsort(cluster_tree.parent)
    parents = cluster_tree.parent[by_parent]
    children = cluster_tree.child[by_parent]

    root = cluster_tree.parent.min()
    for cluster in clusters:
        row = np.searchsorted(cluster_tree.child, cluster)
        death_eps = 1 / cluster_tree.lambda_val[row]
        if not (death_eps < min_epsilon):
            chosen.append(cluster)
            continue
        if cluster in processed:
            continue
        ancestor = traverse_upwards(cluster_tree, min_epsilon, root, cluster)
        chosen.append(ancestor)
        processed |= segments_in_branch(parents, children, ancestor)
    return np.asarray(chosen)
549623
550624
551625@numba .njit ()
def traverse_upwards(cluster_tree, min_epsilon, root, segment):
    """Walk up from ``segment`` to the first ancestor that reaches ``min_epsilon``.

    Parameters
    ----------
    cluster_tree : CondensedTree
        Tree whose ``child`` column is sorted in increasing order; rows are
        located by binary search on that column.
    min_epsilon : float
        Threshold compared against ``1 / lambda_val`` of each ancestor's row.
    root : int
        Id of the root cluster; it has no row of its own in ``child``.
    segment : int
        Id of the cluster to start from.

    Returns
    -------
    int
        ``root`` if the walk reaches it, otherwise the first ancestor whose
        ``1 / lambda_val`` is at least ``min_epsilon``.
    """
    child = cluster_tree.child
    while True:
        parent = cluster_tree.parent[np.searchsorted(child, segment)]
        if parent == root:
            return root
        parent_death_eps = 1 / cluster_tree.lambda_val[np.searchsorted(child, parent)]
        if parent_death_eps >= min_epsilon:
            return parent
        # Parent dies too early as well; continue from it.
        segment = parent
561635
562636
563637@numba .njit ()
def segments_in_branch(parents, children, segment):
    """Collect ``segment`` and every segment below it in the tree.

    Parameters
    ----------
    parents : ndarray
        Parent ids sorted in increasing order.
    children : ndarray
        Child ids, aligned row by row with ``parents``.
    segment : int
        Id of the segment at the top of the branch.

    Returns
    -------
    set
        ``segment`` together with all of its descendants.
    """
    # only way to create a typed empty set
    child_set = {np.int64(0)}
    result = {np.int64(0)}
    result.clear()
    to_process = {segment}
    n_rows = parents.shape[0]

    while len(to_process) > 0:
        result |= to_process

        child_set.clear()
        for node in to_process:
            idx = np.searchsorted(parents, node)
            # searchsorted returns an insertion point, which for a leaf is
            # the first row of some larger, unrelated parent. Only rows whose
            # parent really is this node belong to the branch.
            while idx < n_rows and parents[idx] == node:
                child_set.add(children[idx])
                idx += 1

        to_process.clear()
        to_process |= child_set

    return result
577660
578661
579- @numba .njit (parallel = True )
580- def in_set_parallel (values , targets ):
581- mask = np .empty (values .shape [0 ], dtype = numba .boolean )
582- for i in numba .prange (values .shape [0 ]):
583- mask [i ] = values [i ] in targets
584- return mask
585-
586-
587662@numba .njit (parallel = True )
588663def get_cluster_labelling_at_cut (linkage_tree , cut , min_cluster_size ):
589664
@@ -628,7 +703,7 @@ def get_cluster_label_vector(
628703 cluster_selection_epsilon ,
629704 n_samples ,
630705):
631- if len (tree .parent ) == 0 :
706+ if len (tree .parent ) == 0 or len ( clusters ) == 0 :
632707 return np .full (n_samples , - 1 , dtype = np .intp )
633708 root_cluster = tree .parent .min ()
634709 result = np .full (n_samples , - 1 , dtype = np .intp )
@@ -680,6 +755,14 @@ def max_lambdas(tree, clusters):
680755 return result
681756
682757
758+ @numba .njit ()
def min_lambdas(cluster_tree, clusters):
    """Map each cluster id to the ``lambda_val`` of its row in ``cluster_tree``.

    The row is located by binary search on ``cluster_tree.child``, so that
    column is expected to be sorted and to contain every id in ``clusters``.
    """
    child_ids = cluster_tree.child
    lambdas = cluster_tree.lambda_val
    return {c: lambdas[np.searchsorted(child_ids, c)] for c in clusters}
764+
765+
683766@numba .njit ()
684767def get_point_membership_strength_vector (tree , clusters , labels ):
685768 result = np .zeros (labels .shape [0 ], dtype = np .float32 )
0 commit comments