@@ -287,14 +287,15 @@ def cluster_tree_from_condensed_tree_bcubed(condensed_tree, cluster_tree, label_
287
287
288
288
289
289
@numba .njit ()
290
- def get_condensed_tree_clusters_bcubed (condensed_tree , cluster_tree = None , cluster_tree_bcubed = None , allow_virtual_nodes = False ):
290
+ def get_condensed_tree_clusters_bcubed (condensed_tree , label_indices , cluster_tree = None , cluster_tree_bcubed = None , allow_virtual_nodes = False ):
291
291
292
292
cluster_elements = Dict .empty (
293
293
key_type = int64 ,
294
294
value_type = int64_list_type ,
295
295
)
296
296
297
297
virtual_nodes = [0 for x in range (0 )]
298
+ labeled_points = set (label_indices .keys ())
298
299
299
300
parents_set = set (list (condensed_tree .parent ))
300
301
for i in range (len (condensed_tree .child ) - 1 , - 1 , - 1 ): # Traverse tree bottom up
@@ -304,39 +305,42 @@ def get_condensed_tree_clusters_bcubed(condensed_tree, cluster_tree=None, cluste
304
305
if parent in cluster_elements :
305
306
cluster_elements [parent ].extend (cluster_elements [child ])
306
307
else :
307
- cluster_elements [parent ] = List (cluster_elements [child ])
308
+ cluster_labeled_points = list (set (cluster_elements [child ]) & labeled_points )
309
+ cluster_elements [parent ] = List (cluster_labeled_points )
308
310
elif parent in cluster_elements :
309
- cluster_elements [parent ].append (child )
311
+ if child in labeled_points :
312
+ cluster_elements [parent ].append (child )
310
313
else :
311
314
cluster_elements [parent ] = List .empty_list (int64 )
312
- cluster_elements [parent ].append (child )
315
+ if child in labeled_points :
316
+ cluster_elements [parent ].append (child )
313
317
314
318
if allow_virtual_nodes and (cluster_tree is not None ) and (cluster_tree_bcubed is not None ):
315
- for i in list (set (cluster_tree_bcubed .child ).difference (set (cluster_tree .child ))):
316
- virtual_nodes .append (i )
317
- for node in virtual_nodes :
319
+ for node in list (set (cluster_tree_bcubed .child ).difference (set (cluster_tree .child ))):
320
+ virtual_nodes .append (node )
318
321
cluster_elements [node ] = List .empty_list (int64 )
319
322
cluster_elements [node ].append (node )
320
-
323
+
321
324
return cluster_elements , np .array (virtual_nodes )
322
325
323
326
324
327
@numba .njit ()
325
- def eom_recursion_bcubed (node , cluster_tree , stability_node_scores , bcubed_node_scores , selected_clusters ):
328
+ def eom_recursion_bcubed (node , cluster_tree , stability_node_scores , bcubed_node_scores , selected_clusters , unselected_nodes ):
329
+
326
330
current_score_stability_bcubed = np .array ([stability_node_scores [node ], bcubed_node_scores [node ]], dtype = np .float32 )
327
331
328
332
children = cluster_tree .child [cluster_tree .parent == node ]
329
333
child_score_total_stability_bcubed = np .array ([0.0 , 0.0 ], dtype = np .float32 )
330
334
331
335
for child_node in children :
332
- child_score_total_stability_bcubed += eom_recursion_bcubed (child_node , cluster_tree , stability_node_scores , bcubed_node_scores , selected_clusters )
336
+ child_score_total_stability_bcubed += eom_recursion_bcubed (child_node , cluster_tree , stability_node_scores , bcubed_node_scores , selected_clusters , unselected_nodes )
333
337
334
338
if child_score_total_stability_bcubed [1 ] > current_score_stability_bcubed [1 ]:
335
339
return child_score_total_stability_bcubed
336
340
337
341
elif child_score_total_stability_bcubed [1 ] < current_score_stability_bcubed [1 ]:
338
342
selected_clusters [node ] = True
339
- unselect_below_node (node , cluster_tree , selected_clusters )
343
+ unselect_below_node_bcubed (node , cluster_tree , selected_clusters , unselected_nodes )
340
344
return current_score_stability_bcubed
341
345
342
346
# Stability scores used to resolve ties.
@@ -346,7 +350,7 @@ def eom_recursion_bcubed(node, cluster_tree, stability_node_scores, bcubed_node_
346
350
return child_score_total_stability_bcubed
347
351
else :
348
352
selected_clusters [node ] = True
349
- unselect_below_node (node , cluster_tree , selected_clusters )
353
+ unselect_below_node_bcubed (node , cluster_tree , selected_clusters , unselected_nodes )
350
354
return current_score_stability_bcubed
351
355
352
356
@@ -366,11 +370,9 @@ def score_condensed_tree_nodes_bcubed(cluster_elements, label_indices):
366
370
total_num_of_labeled_points = sum (label_counts_values )
367
371
bcubed = {0 : 0.0 for i in range (0 )}
368
372
369
- for cluster , elements in cluster_elements .items ():
373
+ for cluster , cluster_labeled_points in cluster_elements .items ():
370
374
371
375
cluster_labeled_points_dict = {0 : 0 for i in range (0 )}
372
-
373
- cluster_labeled_points = list (set (elements ) & set (label_indices .keys ()))
374
376
bcubed [cluster ] = 0.0
375
377
376
378
if len (cluster_labeled_points ) > 0 :
@@ -394,23 +396,31 @@ def score_condensed_tree_nodes_bcubed(cluster_elements, label_indices):
394
396
bcubed [cluster ] += num_points * (2.0 / (1.0 / precision_point + 1.0 / recall_point ))
395
397
return bcubed
396
398
399
+ @numba .njit ()
400
+ def unselect_below_node_bcubed (node , cluster_tree , selected_clusters , unselected_nodes ):
401
+
402
+ for child in cluster_tree .child [cluster_tree .parent == node ]:
403
+ if not unselected_nodes [child ]:
404
+ unselect_below_node_bcubed (child , cluster_tree , selected_clusters , unselected_nodes )
405
+ selected_clusters [child ] = False
406
+ unselected_nodes [child ] = True
397
407
398
408
@numba .njit ()
399
409
def extract_clusters_bcubed (condensed_tree , cluster_tree , label_indices , allow_virtual_nodes = False , allow_single_cluster = False ):
400
410
401
411
if allow_virtual_nodes :
402
412
403
413
cluster_tree_bcubed = cluster_tree_from_condensed_tree_bcubed (condensed_tree , cluster_tree , label_indices )
404
- cluster_elements , virtual_nodes = get_condensed_tree_clusters_bcubed (condensed_tree , cluster_tree , cluster_tree_bcubed , allow_virtual_nodes )
414
+ cluster_elements , virtual_nodes = get_condensed_tree_clusters_bcubed (condensed_tree , label_indices , cluster_tree , cluster_tree_bcubed , allow_virtual_nodes )
405
415
stability_node_scores = score_condensed_tree_nodes (condensed_tree )
406
416
for node in virtual_nodes :
407
- stability_node_scores [node ] = 0.0
417
+ stability_node_scores [node ] = np . float32 ( 0.0 )
408
418
bcubed_node_scores = score_condensed_tree_nodes_bcubed (cluster_elements , label_indices )
409
419
410
420
else :
411
421
412
422
cluster_tree_bcubed = cluster_tree
413
- cluster_elements , virtual_nodes = get_condensed_tree_clusters_bcubed (condensed_tree )
423
+ cluster_elements , virtual_nodes = get_condensed_tree_clusters_bcubed (condensed_tree , label_indices )
414
424
stability_node_scores = score_condensed_tree_nodes (condensed_tree )
415
425
bcubed_node_scores = score_condensed_tree_nodes_bcubed (cluster_elements , label_indices )
416
426
@@ -420,13 +430,14 @@ def extract_clusters_bcubed(condensed_tree, cluster_tree, label_indices, allow_v
420
430
return np .zeros (0 , dtype = np .int64 )
421
431
422
432
cluster_tree_root = cluster_tree_bcubed .parent .min ()
433
+ unselected_nodes = {node : False for node in bcubed_node_scores }
423
434
424
435
if allow_single_cluster :
425
- eom_recursion_bcubed (cluster_tree_root , cluster_tree_bcubed , stability_node_scores , bcubed_node_scores , selected_clusters )
436
+ eom_recursion_bcubed (cluster_tree_root , cluster_tree_bcubed , stability_node_scores , bcubed_node_scores , selected_clusters , unselected_nodes )
426
437
elif len (bcubed_node_scores ) > 1 :
427
438
root_children = cluster_tree_bcubed .child [cluster_tree_bcubed .parent == cluster_tree_root ]
428
439
for child_node in root_children :
429
- eom_recursion_bcubed (child_node , cluster_tree_bcubed , stability_node_scores , bcubed_node_scores , selected_clusters )
440
+ eom_recursion_bcubed (child_node , cluster_tree_bcubed , stability_node_scores , bcubed_node_scores , selected_clusters , unselected_nodes )
430
441
431
442
return np .asarray ([node for node , selected in selected_clusters .items () if (selected and (node not in virtual_nodes ))])
432
443
0 commit comments