@@ -450,21 +450,16 @@ def extract_clusters_bcubed(condensed_tree, cluster_tree, data_labels, allow_vir
450450
451451@numba .njit ()
452452def score_condensed_tree_nodes (condensed_tree ):
453- result = {0 : np .float32 (0.0 ) for i in range (0 )}
453+ root = condensed_tree .parent [0 ]
454+ result = {root : np .float32 (0.0 )}
454455
455456 for i in range (condensed_tree .parent .shape [0 ]):
456- parent = condensed_tree .parent [i ]
457- if parent in result :
458- result [parent ] += condensed_tree .lambda_val [i ] * condensed_tree .child_size [i ]
459- else :
460- result [parent ] = condensed_tree .lambda_val [i ] * condensed_tree .child_size [i ]
461-
462457 if condensed_tree .child_size [i ] > 1 :
463458 child = condensed_tree .child [i ]
464- if child in result :
465- result [ child ] -= condensed_tree . lambda_val [ i ] * condensed_tree . child_size [ i ]
466- else :
467- result [child ] = - condensed_tree .lambda_val [i ] * condensed_tree .child_size [i ]
459+ result [ child ] = - condensed_tree . lambda_val [ i ] * condensed_tree . child_size [ i ]
460+
461+ parent = condensed_tree . parent [ i ]
462+ result [parent ] += condensed_tree .lambda_val [i ] * condensed_tree .child_size [i ]
468463
469464 return result
470465
@@ -493,7 +488,7 @@ def unselect_below_node(node, cluster_tree, selected_clusters):
493488
494489@numba .njit (fastmath = True )
495490def eom_recursion (node , cluster_tree , node_scores , node_sizes , selected_clusters , max_cluster_size ):
496- current_score = node_scores [node ]
491+ current_score = max ( node_scores [node ], 0.0 ) # floating point errors can make score negative!
497492 current_size = node_sizes [node ]
498493
499494 children = cluster_tree .child [cluster_tree .parent == node ]
0 commit comments