
Commit 543bc23

Initial pass at max cluster size
1 parent 3263d4f commit 543bc23

3 files changed: +28 −12 lines changed

fast_hdbscan/cluster_trees.py

Lines changed: 17 additions & 11 deletions
@@ -161,7 +161,7 @@ def eliminate_branch(branch_node, parent_node, lambda_value, parents, children,
 
 
 @numba.njit(fastmath=True)
-def condense_tree(hierarchy, min_cluster_size=10, sample_weights=None):
+def condense_tree(hierarchy, min_cluster_size=10, max_cluster_size=np.inf, sample_weights=None):
     root = 2 * hierarchy.shape[0]
     num_points = hierarchy.shape[0] + 1
     next_label = num_points + 1
@@ -223,7 +223,10 @@ def condense_tree(hierarchy, min_cluster_size=10, sample_weights=None):
                                    hierarchy, num_points)
             idx = eliminate_branch(right, parent_node, lambda_value, parents, children, lambdas, sizes, idx, ignore,
                                    hierarchy, num_points)
-        # and finally if we actually have a legitimate cluster split, handle that correctly
+        # If both clusters are too large then relabel both
+        elif left_count > max_cluster_size and right_count > max_cluster_size:
+            relabel[left] = parent_node
+            relabel[right] = parent_node
         else:
             relabel[left] = next_label
 
@@ -471,34 +474,37 @@ def cluster_tree_from_condensed_tree(condensed_tree):
                            condensed_tree.child_size[mask])
 
 
-@numba.njit()
+#@numba.njit()
 def unselect_below_node(node, cluster_tree, selected_clusters):
     for child in cluster_tree.child[cluster_tree.parent == node]:
         unselect_below_node(child, cluster_tree, selected_clusters)
         selected_clusters[child] = False
 
 
-@numba.njit(fastmath=True)
-def eom_recursion(node, cluster_tree, node_scores, selected_clusters):
+#@numba.njit(fastmath=True)
+def eom_recursion(node, cluster_tree, node_scores, node_sizes, selected_clusters, max_cluster_size):
     current_score = node_scores[node]
+    current_size = node_sizes[node]
 
     children = cluster_tree.child[cluster_tree.parent == node]
     child_score_total = 0.0
 
     for child_node in children:
-        child_score_total += eom_recursion(child_node, cluster_tree, node_scores, selected_clusters)
+        child_score_total += eom_recursion(child_node, cluster_tree, node_scores, node_sizes, selected_clusters, max_cluster_size)
 
-    if child_score_total > current_score:
+    if child_score_total > current_score or current_size > max_cluster_size:
         return child_score_total
     else:
         selected_clusters[node] = True
         unselect_below_node(node, cluster_tree, selected_clusters)
         return current_score
 
 
-@numba.njit()
-def extract_eom_clusters(condensed_tree, cluster_tree, allow_single_cluster=False):
+#@numba.njit()
+def extract_eom_clusters(condensed_tree, cluster_tree, max_cluster_size=np.inf, allow_single_cluster=False):
     node_scores = score_condensed_tree_nodes(condensed_tree)
+    node_sizes = {node: size for node, size in zip(cluster_tree.child, cluster_tree.child_size.astype(np.float32))}
+    node_sizes[cluster_tree.parent.min()] = np.float32(cluster_tree.parent.min() - 1)
     selected_clusters = {node: False for node in node_scores}
 
     if len(cluster_tree.parent) == 0:
@@ -507,11 +513,11 @@ def extract_eom_clusters(condensed_tree, cluster_tree, allow_single_cluster=Fals
     cluster_tree_root = cluster_tree.parent.min()
 
     if allow_single_cluster:
-        eom_recursion(cluster_tree_root, cluster_tree, node_scores, selected_clusters)
+        eom_recursion(cluster_tree_root, cluster_tree, node_scores, node_sizes, selected_clusters, max_cluster_size)
     elif len(node_scores) > 1:
         root_children = cluster_tree.child[cluster_tree.parent == cluster_tree_root]
         for child_node in root_children:
-            eom_recursion(child_node, cluster_tree, node_scores, selected_clusters)
+            eom_recursion(child_node, cluster_tree, node_scores, node_sizes, selected_clusters, max_cluster_size)
 
     return np.asarray([node for node, selected in selected_clusters.items() if selected])
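The core of the change is the extra condition in eom_recursion: a node is now passed over not only when its children's combined stability beats its own score, but also whenever its size exceeds max_cluster_size, which forces selection further down the tree. Below is a minimal standalone sketch of that selection rule; the dict-based tree and the names select_eom and _unselect_below are hypothetical illustrations, not the library's numba implementation.

```python
# Standalone toy sketch of excess-of-mass selection with a size cap. The real
# eom_recursion above works on numpy-backed cluster_tree records; the dict-based
# tree and helper names here are illustrative assumptions only.

def select_eom(node, children, scores, sizes, selected, max_cluster_size=float("inf")):
    # Total stability carried by this node's children.
    child_total = 0.0
    for child in children.get(node, ()):
        child_total += select_eom(child, children, scores, sizes, selected, max_cluster_size)

    # Pass over this node if its children are jointly more stable, or if the
    # node itself is larger than the allowed maximum cluster size.
    if child_total > scores[node] or sizes[node] > max_cluster_size:
        return child_total

    # Otherwise select this node and unselect everything below it.
    selected[node] = True
    _unselect_below(node, children, selected)
    return scores[node]


def _unselect_below(node, children, selected):
    for child in children.get(node, ()):
        selected[child] = False
        _unselect_below(child, children, selected)


if __name__ == "__main__":
    # Root 0 splits into 1 and 2; node 1 splits further into 3 and 4.
    children = {0: [1, 2], 1: [3, 4]}
    scores = {0: 5.0, 1: 2.0, 2: 1.0, 3: 0.5, 4: 0.5}
    sizes = {0: 100, 1: 60, 2: 40, 3: 30, 4: 30}

    uncapped = {n: False for n in scores}
    select_eom(0, children, scores, sizes, uncapped)
    print([n for n, s in uncapped.items() if s])  # -> [0]: the large root wins on stability

    capped = {n: False for n in scores}
    select_eom(0, children, scores, sizes, capped, max_cluster_size=50)
    print([n for n, s in capped.items() if s])  # -> [2, 3, 4]: nodes 0 and 1 exceed the cap
```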

fast_hdbscan/hdbscan.py

Lines changed: 4 additions & 1 deletion
@@ -141,6 +141,7 @@ def fast_hdbscan(
     min_samples=10,
     min_cluster_size=10,
     cluster_selection_method="eom",
+    max_cluster_size=np.inf,
     allow_single_cluster=False,
     cluster_selection_epsilon=0.0,
     sample_weights=None,
@@ -214,7 +215,7 @@ def fast_hdbscan(
                raise ValueError(f"Invalid ss_algorithm {ss_algorithm}")
        else:
            selected_clusters = extract_eom_clusters(
-                condensed_tree, cluster_tree, allow_single_cluster=allow_single_cluster
+                condensed_tree, cluster_tree, max_cluster_size=max_cluster_size, allow_single_cluster=allow_single_cluster
            )
    elif cluster_selection_method == "leaf":
        selected_clusters = extract_leaves(
@@ -253,6 +254,7 @@ def __init__(
        min_samples=None,
        cluster_selection_method="eom",
        allow_single_cluster=False,
+        max_cluster_size=np.inf,
        cluster_selection_epsilon=0.0,
        semi_supervised=False,
        ss_algorithm=None,
@@ -262,6 +264,7 @@ def __init__(
        self.min_samples = min_samples
        self.cluster_selection_method = cluster_selection_method
        self.allow_single_cluster = allow_single_cluster
+        self.max_cluster_size = max_cluster_size
        self.cluster_selection_epsilon = cluster_selection_epsilon
        self.semi_supervised = semi_supervised
        self.ss_algorithm = ss_algorithm
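At the estimator level the new keyword simply threads through fast_hdbscan() into extract_eom_clusters. A hedged usage sketch follows; the make_blobs data, min_cluster_size=5, and the cap of 30 are illustrative assumptions that mirror the new test below, not values taken from the commit.

```python
# Illustrative use of the new max_cluster_size parameter; the synthetic data and
# chosen cap are assumptions for this example.
import numpy as np
from sklearn.datasets import make_blobs
from fast_hdbscan import HDBSCAN

X, _ = make_blobs(n_samples=200, centers=2, cluster_std=0.5, random_state=0)

# Without a cap, the two blobs typically come back as two large clusters.
uncapped = HDBSCAN(min_cluster_size=5).fit(X)

# With a cap, oversized clusters are passed over during excess-of-mass selection,
# so their points end up in smaller subclusters (or as noise).
capped = HDBSCAN(min_cluster_size=5, max_cluster_size=30).fit(X)

for label in set(capped.labels_):
    if label != -1:
        assert np.sum(capped.labels_ == label) <= 30
```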

fast_hdbscan/tests/test_hdbscan.py

Lines changed: 7 additions & 0 deletions
@@ -167,6 +167,13 @@ def test_fhdbscan_allow_single_cluster_with_epsilon():
    assert len(unique_labels) == 2
    assert counts[unique_labels == -1] == 2
 
+def test_fhdbscan_max_cluster_size():
+    model = HDBSCAN(max_cluster_size=30).fit(X)
+    assert len(set(model.labels_)) >= 3
+    for label in set(model.labels_):
+        if label != -1:
+            assert np.sum(model.labels_ == label) <= 30
+
 
 # Disable for now -- need to refactor to meet newer standards
 @pytest.mark.skip(reason="need to refactor to meet newer standards")
