Skip to content

Commit a3a5756

Browse files
committed
fix eom recursion; add fit_predict
1 parent ff0b3a5 commit a3a5756

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

fast_hdbscan/cluster_trees.py

Lines changed: 7 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -450,21 +450,16 @@ def extract_clusters_bcubed(condensed_tree, cluster_tree, data_labels, allow_vir
450450

451451
@numba.njit()
452452
def score_condensed_tree_nodes(condensed_tree):
453-
result = {0: np.float32(0.0) for i in range(0)}
453+
root = condensed_tree.parent[0]
454+
result = {root: np.float32(0.0)}
454455

455456
for i in range(condensed_tree.parent.shape[0]):
456-
parent = condensed_tree.parent[i]
457-
if parent in result:
458-
result[parent] += condensed_tree.lambda_val[i] * condensed_tree.child_size[i]
459-
else:
460-
result[parent] = condensed_tree.lambda_val[i] * condensed_tree.child_size[i]
461-
462457
if condensed_tree.child_size[i] > 1:
463458
child = condensed_tree.child[i]
464-
if child in result:
465-
result[child] -= condensed_tree.lambda_val[i] * condensed_tree.child_size[i]
466-
else:
467-
result[child] = -condensed_tree.lambda_val[i] * condensed_tree.child_size[i]
459+
result[child] = -condensed_tree.lambda_val[i] * condensed_tree.child_size[i]
460+
461+
parent = condensed_tree.parent[i]
462+
result[parent] += condensed_tree.lambda_val[i] * condensed_tree.child_size[i]
468463

469464
return result
470465

@@ -493,7 +488,7 @@ def unselect_below_node(node, cluster_tree, selected_clusters):
493488

494489
@numba.njit(fastmath=True)
495490
def eom_recursion(node, cluster_tree, node_scores, node_sizes, selected_clusters, max_cluster_size):
496-
current_score = node_scores[node]
491+
current_score = max(node_scores[node], 0.0) # floating point errors can make score negative!
497492
current_size = node_sizes[node]
498493

499494
children = cluster_tree.child[cluster_tree.parent == node]

fast_hdbscan/hdbscan.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,10 @@ def fit(self, X, y=None, sample_weight=None, **fit_params):
392392

393393
return self
394394

395+
def fit_predict(self, X, y=None, sample_weight=None, **fit_params):
396+
self.fit(X, y, sample_weight, **fit_params)
397+
return self.labels_
398+
395399
def dbscan_clustering(self, epsilon):
396400
check_is_fitted(
397401
self,

0 commit comments

Comments
 (0)