@@ -161,7 +161,7 @@ def eliminate_branch(branch_node, parent_node, lambda_value, parents, children,
161
161
162
162
163
163
@numba .njit (fastmath = True )
164
- def condense_tree (hierarchy , min_cluster_size = 10 , sample_weights = None ):
164
+ def condense_tree (hierarchy , min_cluster_size = 10 , max_cluster_size = np . inf , sample_weights = None ):
165
165
root = 2 * hierarchy .shape [0 ]
166
166
num_points = hierarchy .shape [0 ] + 1
167
167
next_label = num_points + 1
@@ -223,7 +223,10 @@ def condense_tree(hierarchy, min_cluster_size=10, sample_weights=None):
223
223
hierarchy , num_points )
224
224
idx = eliminate_branch (right , parent_node , lambda_value , parents , children , lambdas , sizes , idx , ignore ,
225
225
hierarchy , num_points )
226
- # and finally if we actually have a legitimate cluster split, handle that correctly
226
+ # If both clusters are too large then relabel both
227
+ elif left_count > max_cluster_size and right_count > max_cluster_size :
228
+ relabel [left ] = parent_node
229
+ relabel [right ] = parent_node
227
230
else :
228
231
relabel [left ] = next_label
229
232
@@ -471,34 +474,37 @@ def cluster_tree_from_condensed_tree(condensed_tree):
471
474
condensed_tree .child_size [mask ])
472
475
473
476
474
- @numba .njit ()
477
+ # @numba.njit()
475
478
def unselect_below_node (node , cluster_tree , selected_clusters ):
476
479
for child in cluster_tree .child [cluster_tree .parent == node ]:
477
480
unselect_below_node (child , cluster_tree , selected_clusters )
478
481
selected_clusters [child ] = False
479
482
480
483
481
- @numba .njit (fastmath = True )
482
- def eom_recursion (node , cluster_tree , node_scores , selected_clusters ):
484
+ # @numba.njit(fastmath=True)
485
+ def eom_recursion (node , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size ):
483
486
current_score = node_scores [node ]
487
+ current_size = node_sizes [node ]
484
488
485
489
children = cluster_tree .child [cluster_tree .parent == node ]
486
490
child_score_total = 0.0
487
491
488
492
for child_node in children :
489
- child_score_total += eom_recursion (child_node , cluster_tree , node_scores , selected_clusters )
493
+ child_score_total += eom_recursion (child_node , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size )
490
494
491
- if child_score_total > current_score :
495
+ if child_score_total > current_score or current_size > max_cluster_size :
492
496
return child_score_total
493
497
else :
494
498
selected_clusters [node ] = True
495
499
unselect_below_node (node , cluster_tree , selected_clusters )
496
500
return current_score
497
501
498
502
499
- @numba .njit ()
500
- def extract_eom_clusters (condensed_tree , cluster_tree , allow_single_cluster = False ):
503
+ # @numba.njit()
504
+ def extract_eom_clusters (condensed_tree , cluster_tree , max_cluster_size = np . inf , allow_single_cluster = False ):
501
505
node_scores = score_condensed_tree_nodes (condensed_tree )
506
+ node_sizes = {node : size for node , size in zip (cluster_tree .child , cluster_tree .child_size .astype (np .float32 ))}
507
+ node_sizes [cluster_tree .parent .min ()] = np .float32 (cluster_tree .parent .min () - 1 )
502
508
selected_clusters = {node : False for node in node_scores }
503
509
504
510
if len (cluster_tree .parent ) == 0 :
@@ -507,11 +513,11 @@ def extract_eom_clusters(condensed_tree, cluster_tree, allow_single_cluster=Fals
507
513
cluster_tree_root = cluster_tree .parent .min ()
508
514
509
515
if allow_single_cluster :
510
- eom_recursion (cluster_tree_root , cluster_tree , node_scores , selected_clusters )
516
+ eom_recursion (cluster_tree_root , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size )
511
517
elif len (node_scores ) > 1 :
512
518
root_children = cluster_tree .child [cluster_tree .parent == cluster_tree_root ]
513
519
for child_node in root_children :
514
- eom_recursion (child_node , cluster_tree , node_scores , selected_clusters )
520
+ eom_recursion (child_node , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size )
515
521
516
522
return np .asarray ([node for node , selected in selected_clusters .items () if selected ])
517
523
0 commit comments