@@ -287,14 +287,15 @@ def cluster_tree_from_condensed_tree_bcubed(condensed_tree, cluster_tree, label_
287287
288288
289289@numba .njit ()
290- def get_condensed_tree_clusters_bcubed (condensed_tree , cluster_tree = None , cluster_tree_bcubed = None , allow_virtual_nodes = False ):
290+ def get_condensed_tree_clusters_bcubed (condensed_tree , label_indices , cluster_tree = None , cluster_tree_bcubed = None , allow_virtual_nodes = False ):
291291
292292 cluster_elements = Dict .empty (
293293 key_type = int64 ,
294294 value_type = int64_list_type ,
295295 )
296296
297297 virtual_nodes = [0 for x in range (0 )]
298+ labeled_points = set (label_indices .keys ())
298299
299300 parents_set = set (list (condensed_tree .parent ))
300301 for i in range (len (condensed_tree .child ) - 1 , - 1 , - 1 ): # Traverse tree bottom up
@@ -304,39 +305,42 @@ def get_condensed_tree_clusters_bcubed(condensed_tree, cluster_tree=None, cluste
304305 if parent in cluster_elements :
305306 cluster_elements [parent ].extend (cluster_elements [child ])
306307 else :
307- cluster_elements [parent ] = List (cluster_elements [child ])
308+ cluster_labeled_points = list (set (cluster_elements [child ]) & labeled_points )
309+ cluster_elements [parent ] = List (cluster_labeled_points )
308310 elif parent in cluster_elements :
309- cluster_elements [parent ].append (child )
311+ if child in labeled_points :
312+ cluster_elements [parent ].append (child )
310313 else :
311314 cluster_elements [parent ] = List .empty_list (int64 )
312- cluster_elements [parent ].append (child )
315+ if child in labeled_points :
316+ cluster_elements [parent ].append (child )
313317
314318 if allow_virtual_nodes and (cluster_tree is not None ) and (cluster_tree_bcubed is not None ):
315- for i in list (set (cluster_tree_bcubed .child ).difference (set (cluster_tree .child ))):
316- virtual_nodes .append (i )
317- for node in virtual_nodes :
319+ for node in list (set (cluster_tree_bcubed .child ).difference (set (cluster_tree .child ))):
320+ virtual_nodes .append (node )
318321 cluster_elements [node ] = List .empty_list (int64 )
319322 cluster_elements [node ].append (node )
320-
323+
321324 return cluster_elements , np .array (virtual_nodes )
322325
323326
324327@numba .njit ()
325- def eom_recursion_bcubed (node , cluster_tree , stability_node_scores , bcubed_node_scores , selected_clusters ):
328+ def eom_recursion_bcubed (node , cluster_tree , stability_node_scores , bcubed_node_scores , selected_clusters , unselected_nodes ):
329+
326330 current_score_stability_bcubed = np .array ([stability_node_scores [node ], bcubed_node_scores [node ]], dtype = np .float32 )
327331
328332 children = cluster_tree .child [cluster_tree .parent == node ]
329333 child_score_total_stability_bcubed = np .array ([0.0 , 0.0 ], dtype = np .float32 )
330334
331335 for child_node in children :
332- child_score_total_stability_bcubed += eom_recursion_bcubed (child_node , cluster_tree , stability_node_scores , bcubed_node_scores , selected_clusters )
336+ child_score_total_stability_bcubed += eom_recursion_bcubed (child_node , cluster_tree , stability_node_scores , bcubed_node_scores , selected_clusters , unselected_nodes )
333337
334338 if child_score_total_stability_bcubed [1 ] > current_score_stability_bcubed [1 ]:
335339 return child_score_total_stability_bcubed
336340
337341 elif child_score_total_stability_bcubed [1 ] < current_score_stability_bcubed [1 ]:
338342 selected_clusters [node ] = True
339- unselect_below_node (node , cluster_tree , selected_clusters )
343+ unselect_below_node_bcubed (node , cluster_tree , selected_clusters , unselected_nodes )
340344 return current_score_stability_bcubed
341345
342346 # Stability scores used to resolve ties.
@@ -346,7 +350,7 @@ def eom_recursion_bcubed(node, cluster_tree, stability_node_scores, bcubed_node_
346350 return child_score_total_stability_bcubed
347351 else :
348352 selected_clusters [node ] = True
349- unselect_below_node (node , cluster_tree , selected_clusters )
353+ unselect_below_node_bcubed (node , cluster_tree , selected_clusters , unselected_nodes )
350354 return current_score_stability_bcubed
351355
352356
@@ -366,11 +370,9 @@ def score_condensed_tree_nodes_bcubed(cluster_elements, label_indices):
366370 total_num_of_labeled_points = sum (label_counts_values )
367371 bcubed = {0 : 0.0 for i in range (0 )}
368372
369- for cluster , elements in cluster_elements .items ():
373+ for cluster , cluster_labeled_points in cluster_elements .items ():
370374
371375 cluster_labeled_points_dict = {0 : 0 for i in range (0 )}
372-
373- cluster_labeled_points = list (set (elements ) & set (label_indices .keys ()))
374376 bcubed [cluster ] = 0.0
375377
376378 if len (cluster_labeled_points ) > 0 :
@@ -394,23 +396,31 @@ def score_condensed_tree_nodes_bcubed(cluster_elements, label_indices):
394396 bcubed [cluster ] += num_points * (2.0 / (1.0 / precision_point + 1.0 / recall_point ))
395397 return bcubed
396398
399+ @numba .njit ()
400+ def unselect_below_node_bcubed (node , cluster_tree , selected_clusters , unselected_nodes ):
401+
402+ for child in cluster_tree .child [cluster_tree .parent == node ]:
403+ if not unselected_nodes [child ]:
404+ unselect_below_node_bcubed (child , cluster_tree , selected_clusters , unselected_nodes )
405+ selected_clusters [child ] = False
406+ unselected_nodes [child ] = True
397407
398408@numba .njit ()
399409def extract_clusters_bcubed (condensed_tree , cluster_tree , label_indices , allow_virtual_nodes = False , allow_single_cluster = False ):
400410
401411 if allow_virtual_nodes :
402412
403413 cluster_tree_bcubed = cluster_tree_from_condensed_tree_bcubed (condensed_tree , cluster_tree , label_indices )
404- cluster_elements , virtual_nodes = get_condensed_tree_clusters_bcubed (condensed_tree , cluster_tree , cluster_tree_bcubed , allow_virtual_nodes )
414+ cluster_elements , virtual_nodes = get_condensed_tree_clusters_bcubed (condensed_tree , label_indices , cluster_tree , cluster_tree_bcubed , allow_virtual_nodes )
405415 stability_node_scores = score_condensed_tree_nodes (condensed_tree )
406416 for node in virtual_nodes :
407- stability_node_scores [node ] = 0.0
417+ stability_node_scores [node ] = np . float32 ( 0.0 )
408418 bcubed_node_scores = score_condensed_tree_nodes_bcubed (cluster_elements , label_indices )
409419
410420 else :
411421
412422 cluster_tree_bcubed = cluster_tree
413- cluster_elements , virtual_nodes = get_condensed_tree_clusters_bcubed (condensed_tree )
423+ cluster_elements , virtual_nodes = get_condensed_tree_clusters_bcubed (condensed_tree , label_indices )
414424 stability_node_scores = score_condensed_tree_nodes (condensed_tree )
415425 bcubed_node_scores = score_condensed_tree_nodes_bcubed (cluster_elements , label_indices )
416426
@@ -420,13 +430,14 @@ def extract_clusters_bcubed(condensed_tree, cluster_tree, label_indices, allow_v
420430 return np .zeros (0 , dtype = np .int64 )
421431
422432 cluster_tree_root = cluster_tree_bcubed .parent .min ()
433+ unselected_nodes = {node : False for node in bcubed_node_scores }
423434
424435 if allow_single_cluster :
425- eom_recursion_bcubed (cluster_tree_root , cluster_tree_bcubed , stability_node_scores , bcubed_node_scores , selected_clusters )
436+ eom_recursion_bcubed (cluster_tree_root , cluster_tree_bcubed , stability_node_scores , bcubed_node_scores , selected_clusters , unselected_nodes )
426437 elif len (bcubed_node_scores ) > 1 :
427438 root_children = cluster_tree_bcubed .child [cluster_tree_bcubed .parent == cluster_tree_root ]
428439 for child_node in root_children :
429- eom_recursion_bcubed (child_node , cluster_tree_bcubed , stability_node_scores , bcubed_node_scores , selected_clusters )
440+ eom_recursion_bcubed (child_node , cluster_tree_bcubed , stability_node_scores , bcubed_node_scores , selected_clusters , unselected_nodes )
430441
431442 return np .asarray ([node for node , selected in selected_clusters .items () if (selected and (node not in virtual_nodes ))])
432443
0 commit comments