Skip to content

Commit 05a4a62

Browse files
committed
Add changes to make semi-supervised code more efficient.
1 parent 40435f7 commit 05a4a62

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

fast_hdbscan/cluster_trees.py

Lines changed: 31 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -287,14 +287,15 @@ def cluster_tree_from_condensed_tree_bcubed(condensed_tree, cluster_tree, label_
287287

288288

289289
@numba.njit()
290-
def get_condensed_tree_clusters_bcubed(condensed_tree, cluster_tree=None, cluster_tree_bcubed=None, allow_virtual_nodes=False):
290+
def get_condensed_tree_clusters_bcubed(condensed_tree, label_indices, cluster_tree=None, cluster_tree_bcubed=None, allow_virtual_nodes=False):
291291

292292
cluster_elements = Dict.empty(
293293
key_type=int64,
294294
value_type=int64_list_type,
295295
)
296296

297297
virtual_nodes = [0 for x in range(0)]
298+
labeled_points = set(label_indices.keys())
298299

299300
parents_set = set(list(condensed_tree.parent))
300301
for i in range(len(condensed_tree.child) - 1, -1, -1): # Traverse tree bottom up
@@ -304,39 +305,42 @@ def get_condensed_tree_clusters_bcubed(condensed_tree, cluster_tree=None, cluste
304305
if parent in cluster_elements:
305306
cluster_elements[parent].extend(cluster_elements[child])
306307
else:
307-
cluster_elements[parent] = List(cluster_elements[child])
308+
cluster_labeled_points = list(set(cluster_elements[child]) & labeled_points)
309+
cluster_elements[parent] = List(cluster_labeled_points)
308310
elif parent in cluster_elements:
309-
cluster_elements[parent].append(child)
311+
if child in labeled_points:
312+
cluster_elements[parent].append(child)
310313
else:
311314
cluster_elements[parent] = List.empty_list(int64)
312-
cluster_elements[parent].append(child)
315+
if child in labeled_points:
316+
cluster_elements[parent].append(child)
313317

314318
if allow_virtual_nodes and (cluster_tree is not None) and (cluster_tree_bcubed is not None):
315-
for i in list(set(cluster_tree_bcubed.child).difference(set(cluster_tree.child))):
316-
virtual_nodes.append(i)
317-
for node in virtual_nodes:
319+
for node in list(set(cluster_tree_bcubed.child).difference(set(cluster_tree.child))):
320+
virtual_nodes.append(node)
318321
cluster_elements[node] = List.empty_list(int64)
319322
cluster_elements[node].append(node)
320-
323+
321324
return cluster_elements, np.array(virtual_nodes)
322325

323326

324327
@numba.njit()
325-
def eom_recursion_bcubed(node, cluster_tree, stability_node_scores, bcubed_node_scores, selected_clusters):
328+
def eom_recursion_bcubed(node, cluster_tree, stability_node_scores, bcubed_node_scores, selected_clusters, unselected_nodes):
329+
326330
current_score_stability_bcubed = np.array([stability_node_scores[node], bcubed_node_scores[node]], dtype=np.float32)
327331

328332
children = cluster_tree.child[cluster_tree.parent == node]
329333
child_score_total_stability_bcubed = np.array([0.0, 0.0], dtype=np.float32)
330334

331335
for child_node in children:
332-
child_score_total_stability_bcubed += eom_recursion_bcubed(child_node, cluster_tree, stability_node_scores, bcubed_node_scores, selected_clusters)
336+
child_score_total_stability_bcubed += eom_recursion_bcubed(child_node, cluster_tree, stability_node_scores, bcubed_node_scores, selected_clusters, unselected_nodes)
333337

334338
if child_score_total_stability_bcubed[1] > current_score_stability_bcubed[1]:
335339
return child_score_total_stability_bcubed
336340

337341
elif child_score_total_stability_bcubed[1] < current_score_stability_bcubed[1]:
338342
selected_clusters[node] = True
339-
unselect_below_node(node, cluster_tree, selected_clusters)
343+
unselect_below_node_bcubed(node, cluster_tree, selected_clusters, unselected_nodes)
340344
return current_score_stability_bcubed
341345

342346
# Stability scores used to resolve ties.
@@ -346,7 +350,7 @@ def eom_recursion_bcubed(node, cluster_tree, stability_node_scores, bcubed_node_
346350
return child_score_total_stability_bcubed
347351
else:
348352
selected_clusters[node] = True
349-
unselect_below_node(node, cluster_tree, selected_clusters)
353+
unselect_below_node_bcubed(node, cluster_tree, selected_clusters, unselected_nodes)
350354
return current_score_stability_bcubed
351355

352356

@@ -366,11 +370,9 @@ def score_condensed_tree_nodes_bcubed(cluster_elements, label_indices):
366370
total_num_of_labeled_points = sum(label_counts_values)
367371
bcubed = {0: 0.0 for i in range(0)}
368372

369-
for cluster, elements in cluster_elements.items():
373+
for cluster, cluster_labeled_points in cluster_elements.items():
370374

371375
cluster_labeled_points_dict = {0: 0 for i in range(0)}
372-
373-
cluster_labeled_points = list(set(elements) & set(label_indices.keys()))
374376
bcubed[cluster] = 0.0
375377

376378
if len(cluster_labeled_points) > 0:
@@ -394,23 +396,31 @@ def score_condensed_tree_nodes_bcubed(cluster_elements, label_indices):
394396
bcubed[cluster] += num_points*(2.0/(1.0/precision_point + 1.0/recall_point))
395397
return bcubed
396398

399+
@numba.njit()
400+
def unselect_below_node_bcubed(node, cluster_tree, selected_clusters, unselected_nodes):
401+
402+
for child in cluster_tree.child[cluster_tree.parent == node]:
403+
if not unselected_nodes[child]:
404+
unselect_below_node_bcubed(child, cluster_tree, selected_clusters, unselected_nodes)
405+
selected_clusters[child] = False
406+
unselected_nodes[child] = True
397407

398408
@numba.njit()
399409
def extract_clusters_bcubed(condensed_tree, cluster_tree, label_indices, allow_virtual_nodes=False, allow_single_cluster=False):
400410

401411
if allow_virtual_nodes:
402412

403413
cluster_tree_bcubed = cluster_tree_from_condensed_tree_bcubed(condensed_tree, cluster_tree, label_indices)
404-
cluster_elements, virtual_nodes = get_condensed_tree_clusters_bcubed(condensed_tree, cluster_tree, cluster_tree_bcubed, allow_virtual_nodes)
414+
cluster_elements, virtual_nodes = get_condensed_tree_clusters_bcubed(condensed_tree, label_indices, cluster_tree, cluster_tree_bcubed, allow_virtual_nodes)
405415
stability_node_scores = score_condensed_tree_nodes(condensed_tree)
406416
for node in virtual_nodes:
407-
stability_node_scores[node] = 0.0
417+
stability_node_scores[node] = np.float32(0.0)
408418
bcubed_node_scores = score_condensed_tree_nodes_bcubed(cluster_elements, label_indices)
409419

410420
else:
411421

412422
cluster_tree_bcubed = cluster_tree
413-
cluster_elements, virtual_nodes = get_condensed_tree_clusters_bcubed(condensed_tree)
423+
cluster_elements, virtual_nodes = get_condensed_tree_clusters_bcubed(condensed_tree, label_indices)
414424
stability_node_scores = score_condensed_tree_nodes(condensed_tree)
415425
bcubed_node_scores = score_condensed_tree_nodes_bcubed(cluster_elements, label_indices)
416426

@@ -420,13 +430,14 @@ def extract_clusters_bcubed(condensed_tree, cluster_tree, label_indices, allow_v
420430
return np.zeros(0, dtype=np.int64)
421431

422432
cluster_tree_root = cluster_tree_bcubed.parent.min()
433+
unselected_nodes = {node: False for node in bcubed_node_scores}
423434

424435
if allow_single_cluster:
425-
eom_recursion_bcubed(cluster_tree_root, cluster_tree_bcubed, stability_node_scores, bcubed_node_scores, selected_clusters)
436+
eom_recursion_bcubed(cluster_tree_root, cluster_tree_bcubed, stability_node_scores, bcubed_node_scores, selected_clusters, unselected_nodes)
426437
elif len(bcubed_node_scores) > 1:
427438
root_children = cluster_tree_bcubed.child[cluster_tree_bcubed.parent == cluster_tree_root]
428439
for child_node in root_children:
429-
eom_recursion_bcubed(child_node, cluster_tree_bcubed, stability_node_scores, bcubed_node_scores, selected_clusters)
440+
eom_recursion_bcubed(child_node, cluster_tree_bcubed, stability_node_scores, bcubed_node_scores, selected_clusters, unselected_nodes)
430441

431442
return np.asarray([node for node, selected in selected_clusters.items() if (selected and (node not in virtual_nodes))])
432443

fast_hdbscan/hdbscan.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -202,7 +202,7 @@ def fast_hdbscan(
202202
allow_virtual_nodes=True,
203203
allow_single_cluster=allow_single_cluster,
204204
)
205-
elif ss_algorithm == "bc_without_vn":
205+
elif ss_algorithm == "bc_simple":
206206
selected_clusters = extract_clusters_bcubed(
207207
condensed_tree,
208208
cluster_tree,

0 commit comments

Comments
 (0)