@@ -161,7 +161,7 @@ def eliminate_branch(branch_node, parent_node, lambda_value, parents, children,
161161
162162
163163@numba .njit (fastmath = True )
164- def condense_tree (hierarchy , min_cluster_size = 10 , sample_weights = None ):
164+ def condense_tree (hierarchy , min_cluster_size = 10 , max_cluster_size = np . inf , sample_weights = None ):
165165 root = 2 * hierarchy .shape [0 ]
166166 num_points = hierarchy .shape [0 ] + 1
167167 next_label = num_points + 1
@@ -223,7 +223,10 @@ def condense_tree(hierarchy, min_cluster_size=10, sample_weights=None):
223223 hierarchy , num_points )
224224 idx = eliminate_branch (right , parent_node , lambda_value , parents , children , lambdas , sizes , idx , ignore ,
225225 hierarchy , num_points )
226- # and finally if we actually have a legitimate cluster split, handle that correctly
226+ # If both clusters are too large then relabel both
227+ elif left_count > max_cluster_size and right_count > max_cluster_size :
228+ relabel [left ] = parent_node
229+ relabel [right ] = parent_node
227230 else :
228231 relabel [left ] = next_label
229232
@@ -471,34 +474,37 @@ def cluster_tree_from_condensed_tree(condensed_tree):
471474 condensed_tree .child_size [mask ])
472475
473476
# NOTE: the `@numba.njit()` decorator is commented out in this revision, so
# this function currently runs as plain (interpreted) Python.
# @numba.njit()
def unselect_below_node(node, cluster_tree, selected_clusters):
    """Mark every descendant of ``node`` in the cluster tree as unselected.

    Parameters
    ----------
    node
        Label of the cluster whose descendants are cleared. The entry for
        ``node`` itself is not modified.
    cluster_tree
        Object exposing ``parent`` and ``child`` arrays that support
        boolean-mask indexing (e.g. NumPy arrays); ``child[i]`` is treated
        as a child of ``parent[i]``.
    selected_clusters
        Mutable mapping from cluster label to a selection flag. It is
        updated in place: every descendant of ``node`` is set to ``False``.

    Returns
    -------
    None
    """
    # An explicit stack replaces the original self-recursion so that a deep
    # cluster tree cannot exhaust the interpreter's recursion limit while
    # the JIT decorator is disabled. The set of labels cleared is unchanged.
    pending = [node]
    while pending:
        current = pending.pop()
        for child in cluster_tree.child[cluster_tree.parent == current]:
            selected_clusters[child] = False
            pending.append(child)
479482
480483
# NOTE: the `@numba.njit(fastmath=True)` decorator is commented out in this
# revision, so this function currently runs as plain (interpreted) Python.
# @numba.njit(fastmath=True)
def eom_recursion(
    node,
    cluster_tree,
    node_scores,
    node_sizes,
    selected_clusters,
    max_cluster_size,
):
    """Select clusters in the subtree rooted at ``node`` and return its score.

    The children of ``node`` are processed first. ``node`` is then selected
    (and everything below it unselected) unless its children's combined
    score is strictly greater than its own score, or its own size exceeds
    ``max_cluster_size``.

    Parameters
    ----------
    node
        Label of the cluster-tree node to evaluate.
    cluster_tree
        Object exposing ``parent`` and ``child`` arrays that support
        boolean-mask indexing (e.g. NumPy arrays).
    node_scores
        Mapping from node label to that node's score.
    node_sizes
        Mapping from node label to that node's size.
    selected_clusters
        Mutable mapping from node label to a selection flag; updated in
        place.
    max_cluster_size
        Size above which a node is never selected.

    Returns
    -------
    The score credited to this subtree: the children's combined score when
    ``node`` is not selected, otherwise ``node``'s own score.
    """
    current_score = node_scores[node]
    current_size = node_sizes[node]

    children = cluster_tree.child[cluster_tree.parent == node]

    # Accumulate left to right, starting from 0.0, exactly as written in the
    # patch so the floating-point result does not change.
    child_score_total = 0.0
    for child_node in children:
        child_score_total += eom_recursion(
            child_node,
            cluster_tree,
            node_scores,
            node_sizes,
            selected_clusters,
            max_cluster_size,
        )

    # An oversized node is passed over even when it out-scores its children;
    # the children's total is returned in its place.
    if child_score_total > current_score or current_size > max_cluster_size:
        return child_score_total

    selected_clusters[node] = True
    unselect_below_node(node, cluster_tree, selected_clusters)
    return current_score
497501
498502
499- @numba .njit ()
500- def extract_eom_clusters (condensed_tree , cluster_tree , allow_single_cluster = False ):
503+ # @numba.njit()
504+ def extract_eom_clusters (condensed_tree , cluster_tree , max_cluster_size = np . inf , allow_single_cluster = False ):
501505 node_scores = score_condensed_tree_nodes (condensed_tree )
506+ node_sizes = {node : size for node , size in zip (cluster_tree .child , cluster_tree .child_size .astype (np .float32 ))}
507+ node_sizes [cluster_tree .parent .min ()] = np .float32 (cluster_tree .parent .min () - 1 )
502508 selected_clusters = {node : False for node in node_scores }
503509
504510 if len (cluster_tree .parent ) == 0 :
@@ -507,11 +513,11 @@ def extract_eom_clusters(condensed_tree, cluster_tree, allow_single_cluster=Fals
507513 cluster_tree_root = cluster_tree .parent .min ()
508514
509515 if allow_single_cluster :
510- eom_recursion (cluster_tree_root , cluster_tree , node_scores , selected_clusters )
516+ eom_recursion (cluster_tree_root , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size )
511517 elif len (node_scores ) > 1 :
512518 root_children = cluster_tree .child [cluster_tree .parent == cluster_tree_root ]
513519 for child_node in root_children :
514- eom_recursion (child_node , cluster_tree , node_scores , selected_clusters )
520+ eom_recursion (child_node , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size )
515521
516522 return np .asarray ([node for node , selected in selected_clusters .items () if selected ])
517523
0 commit comments