@@ -474,14 +474,14 @@ def cluster_tree_from_condensed_tree(condensed_tree):
474
474
condensed_tree .child_size [mask ])
475
475
476
476
477
- # @numba.njit()
477
+ @numba .njit ()
478
478
def unselect_below_node (node , cluster_tree , selected_clusters ):
479
479
for child in cluster_tree .child [cluster_tree .parent == node ]:
480
480
unselect_below_node (child , cluster_tree , selected_clusters )
481
481
selected_clusters [child ] = False
482
482
483
483
484
- # @numba.njit(fastmath=True)
484
+ @numba .njit (fastmath = True )
485
485
def eom_recursion (node , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size ):
486
486
current_score = node_scores [node ]
487
487
current_size = node_sizes [node ]
@@ -500,11 +500,14 @@ def eom_recursion(node, cluster_tree, node_scores, node_sizes, selected_clusters
500
500
return current_score
501
501
502
502
503
- # @numba.njit()
503
+ @numba .njit ()
504
504
def extract_eom_clusters (condensed_tree , cluster_tree , max_cluster_size = np .inf , allow_single_cluster = False ):
505
505
node_scores = score_condensed_tree_nodes (condensed_tree )
506
- node_sizes = {node : size for node , size in zip (cluster_tree .child , cluster_tree .child_size .astype (np .float32 ))}
507
- node_sizes [cluster_tree .parent .min ()] = np .float32 (cluster_tree .parent .min () - 1 )
506
+ if len (cluster_tree .parent ) > 0 :
507
+ node_sizes = {node : size for node , size in zip (cluster_tree .child , cluster_tree .child_size .astype (np .float32 ))}
508
+ node_sizes [cluster_tree .parent .min ()] = np .float32 (cluster_tree .parent .min () - 1 )
509
+ else :
510
+ node_sizes = {- 1 : np .float32 (0.0 )}
508
511
selected_clusters = {node : False for node in node_scores }
509
512
510
513
if len (cluster_tree .parent ) == 0 :
0 commit comments