@@ -161,7 +161,7 @@ def eliminate_branch(branch_node, parent_node, lambda_value, parents, children,
161
161
162
162
163
163
@numba .njit (fastmath = True )
164
- def condense_tree (hierarchy , min_cluster_size = 10 , sample_weights = None ):
164
+ def condense_tree (hierarchy , min_cluster_size = 10 , max_cluster_size = np . inf , sample_weights = None ):
165
165
root = 2 * hierarchy .shape [0 ]
166
166
num_points = hierarchy .shape [0 ] + 1
167
167
next_label = num_points + 1
@@ -223,7 +223,10 @@ def condense_tree(hierarchy, min_cluster_size=10, sample_weights=None):
223
223
hierarchy , num_points )
224
224
idx = eliminate_branch (right , parent_node , lambda_value , parents , children , lambdas , sizes , idx , ignore ,
225
225
hierarchy , num_points )
226
- # and finally if we actually have a legitimate cluster split, handle that correctly
226
+ # If both children exceed max_cluster_size, relabel both to parent_node rather than assigning them new labels
227
+ elif left_count > max_cluster_size and right_count > max_cluster_size :
228
+ relabel [left ] = parent_node
229
+ relabel [right ] = parent_node
227
230
else :
228
231
relabel [left ] = next_label
229
232
@@ -479,16 +482,17 @@ def unselect_below_node(node, cluster_tree, selected_clusters):
479
482
480
483
481
484
@numba .njit (fastmath = True )
482
- def eom_recursion (node , cluster_tree , node_scores , selected_clusters ):
485
+ def eom_recursion (node , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size ):
483
486
current_score = node_scores [node ]
487
+ current_size = node_sizes [node ]
484
488
485
489
children = cluster_tree .child [cluster_tree .parent == node ]
486
490
child_score_total = 0.0
487
491
488
492
for child_node in children :
489
- child_score_total += eom_recursion (child_node , cluster_tree , node_scores , selected_clusters )
493
+ child_score_total += eom_recursion (child_node , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size )
490
494
491
- if child_score_total > current_score :
495
+ if child_score_total > current_score or current_size > max_cluster_size :
492
496
return child_score_total
493
497
else :
494
498
selected_clusters [node ] = True
@@ -497,8 +501,13 @@ def eom_recursion(node, cluster_tree, node_scores, selected_clusters):
497
501
498
502
499
503
@numba .njit ()
500
- def extract_eom_clusters (condensed_tree , cluster_tree , allow_single_cluster = False ):
504
+ def extract_eom_clusters (condensed_tree , cluster_tree , max_cluster_size = np . inf , allow_single_cluster = False ):
501
505
node_scores = score_condensed_tree_nodes (condensed_tree )
506
+ if len (cluster_tree .parent ) > 0 :
507
+ node_sizes = {node : size for node , size in zip (cluster_tree .child , cluster_tree .child_size .astype (np .float32 ))}
508
+ node_sizes [cluster_tree .parent .min ()] = np .float32 (cluster_tree .parent .min () - 1 )
509
+ else :
510
+ node_sizes = {- 1 : np .float32 (0.0 )}
502
511
selected_clusters = {node : False for node in node_scores }
503
512
504
513
if len (cluster_tree .parent ) == 0 :
@@ -507,11 +516,11 @@ def extract_eom_clusters(condensed_tree, cluster_tree, allow_single_cluster=Fals
507
516
cluster_tree_root = cluster_tree .parent .min ()
508
517
509
518
if allow_single_cluster :
510
- eom_recursion (cluster_tree_root , cluster_tree , node_scores , selected_clusters )
519
+ eom_recursion (cluster_tree_root , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size )
511
520
elif len (node_scores ) > 1 :
512
521
root_children = cluster_tree .child [cluster_tree .parent == cluster_tree_root ]
513
522
for child_node in root_children :
514
- eom_recursion (child_node , cluster_tree , node_scores , selected_clusters )
523
+ eom_recursion (child_node , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size )
515
524
516
525
return np .asarray ([node for node , selected in selected_clusters .items () if selected ])
517
526
0 commit comments