@@ -474,14 +474,14 @@ def cluster_tree_from_condensed_tree(condensed_tree):
474474 condensed_tree .child_size [mask ])
475475
476476
477- # @numba.njit()
477+ @numba .njit ()
478478def unselect_below_node (node , cluster_tree , selected_clusters ):
479479 for child in cluster_tree .child [cluster_tree .parent == node ]:
480480 unselect_below_node (child , cluster_tree , selected_clusters )
481481 selected_clusters [child ] = False
482482
483483
484- # @numba.njit(fastmath=True)
484+ @numba .njit (fastmath = True )
485485def eom_recursion (node , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size ):
486486 current_score = node_scores [node ]
487487 current_size = node_sizes [node ]
@@ -500,11 +500,14 @@ def eom_recursion(node, cluster_tree, node_scores, node_sizes, selected_clusters
500500 return current_score
501501
502502
503- # @numba.njit()
503+ @numba .njit ()
504504def extract_eom_clusters (condensed_tree , cluster_tree , max_cluster_size = np .inf , allow_single_cluster = False ):
505505 node_scores = score_condensed_tree_nodes (condensed_tree )
506- node_sizes = {node : size for node , size in zip (cluster_tree .child , cluster_tree .child_size .astype (np .float32 ))}
507- node_sizes [cluster_tree .parent .min ()] = np .float32 (cluster_tree .parent .min () - 1 )
506+ if len (cluster_tree .parent ) > 0 :
507+ node_sizes = {node : size for node , size in zip (cluster_tree .child , cluster_tree .child_size .astype (np .float32 ))}
508+ node_sizes [cluster_tree .parent .min ()] = np .float32 (cluster_tree .parent .min () - 1 )
509+ else :
510+ node_sizes = {- 1 : np .float32 (0.0 )}
508511 selected_clusters = {node : False for node in node_scores }
509512
510513 if len (cluster_tree .parent ) == 0 :
0 commit comments