@@ -161,7 +161,7 @@ def eliminate_branch(branch_node, parent_node, lambda_value, parents, children,
161161
162162
163163@numba .njit (fastmath = True )
164- def condense_tree (hierarchy , min_cluster_size = 10 , sample_weights = None ):
164+ def condense_tree (hierarchy , min_cluster_size = 10 , max_cluster_size = np . inf , sample_weights = None ):
165165 root = 2 * hierarchy .shape [0 ]
166166 num_points = hierarchy .shape [0 ] + 1
167167 next_label = num_points + 1
@@ -223,7 +223,10 @@ def condense_tree(hierarchy, min_cluster_size=10, sample_weights=None):
223223 hierarchy , num_points )
224224 idx = eliminate_branch (right , parent_node , lambda_value , parents , children , lambdas , sizes , idx , ignore ,
225225 hierarchy , num_points )
226- # and finally if we actually have a legitimate cluster split, handle that correctly
226+ # If both clusters are too large then relabel both
227+ elif left_count > max_cluster_size and right_count > max_cluster_size :
228+ relabel [left ] = parent_node
229+ relabel [right ] = parent_node
227230 else :
228231 relabel [left ] = next_label
229232
@@ -479,16 +482,17 @@ def unselect_below_node(node, cluster_tree, selected_clusters):
479482
480483
481484@numba .njit (fastmath = True )
482- def eom_recursion (node , cluster_tree , node_scores , selected_clusters ):
485+ def eom_recursion (node , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size ):
483486 current_score = node_scores [node ]
487+ current_size = node_sizes [node ]
484488
485489 children = cluster_tree .child [cluster_tree .parent == node ]
486490 child_score_total = 0.0
487491
488492 for child_node in children :
489- child_score_total += eom_recursion (child_node , cluster_tree , node_scores , selected_clusters )
493+ child_score_total += eom_recursion (child_node , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size )
490494
491- if child_score_total > current_score :
495+ if child_score_total > current_score or current_size > max_cluster_size :
492496 return child_score_total
493497 else :
494498 selected_clusters [node ] = True
@@ -497,8 +501,13 @@ def eom_recursion(node, cluster_tree, node_scores, selected_clusters):
497501
498502
499503@numba .njit ()
500- def extract_eom_clusters (condensed_tree , cluster_tree , allow_single_cluster = False ):
504+ def extract_eom_clusters (condensed_tree , cluster_tree , max_cluster_size = np . inf , allow_single_cluster = False ):
501505 node_scores = score_condensed_tree_nodes (condensed_tree )
506+ if len (cluster_tree .parent ) > 0 :
507+ node_sizes = {node : size for node , size in zip (cluster_tree .child , cluster_tree .child_size .astype (np .float32 ))}
508+ node_sizes [cluster_tree .parent .min ()] = np .float32 (cluster_tree .parent .min () - 1 )
509+ else :
510+ node_sizes = {- 1 : np .float32 (0.0 )}
502511 selected_clusters = {node : False for node in node_scores }
503512
504513 if len (cluster_tree .parent ) == 0 :
@@ -507,11 +516,11 @@ def extract_eom_clusters(condensed_tree, cluster_tree, allow_single_cluster=Fals
507516 cluster_tree_root = cluster_tree .parent .min ()
508517
509518 if allow_single_cluster :
510- eom_recursion (cluster_tree_root , cluster_tree , node_scores , selected_clusters )
519+ eom_recursion (cluster_tree_root , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size )
511520 elif len (node_scores ) > 1 :
512521 root_children = cluster_tree .child [cluster_tree .parent == cluster_tree_root ]
513522 for child_node in root_children :
514- eom_recursion (child_node , cluster_tree , node_scores , selected_clusters )
523+ eom_recursion (child_node , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size )
515524
516525 return np .asarray ([node for node , selected in selected_clusters .items () if selected ])
517526
0 commit comments