Skip to content

Commit e6e6b5e

Browse files
authored
Merge pull request TutteInstitute#29 from TutteInstitute/max_cluster_size
Add support for a max cluster size
Parents: 3263d4f, cd60746 · Commit: e6e6b5e

File tree

3 files changed

+47
-22
lines changed

3 files changed

+47
-22
lines changed

fast_hdbscan/cluster_trees.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def eliminate_branch(branch_node, parent_node, lambda_value, parents, children,
161161

162162

163163
@numba.njit(fastmath=True)
164-
def condense_tree(hierarchy, min_cluster_size=10, sample_weights=None):
164+
def condense_tree(hierarchy, min_cluster_size=10, max_cluster_size=np.inf, sample_weights=None):
165165
root = 2 * hierarchy.shape[0]
166166
num_points = hierarchy.shape[0] + 1
167167
next_label = num_points + 1
@@ -223,7 +223,10 @@ def condense_tree(hierarchy, min_cluster_size=10, sample_weights=None):
223223
hierarchy, num_points)
224224
idx = eliminate_branch(right, parent_node, lambda_value, parents, children, lambdas, sizes, idx, ignore,
225225
hierarchy, num_points)
226-
# and finally if we actually have a legitimate cluster split, handle that correctly
226+
# If both clusters are too large then relabel both
227+
elif left_count > max_cluster_size and right_count > max_cluster_size:
228+
relabel[left] = parent_node
229+
relabel[right] = parent_node
227230
else:
228231
relabel[left] = next_label
229232

@@ -479,16 +482,17 @@ def unselect_below_node(node, cluster_tree, selected_clusters):
479482

480483

481484
@numba.njit(fastmath=True)
482-
def eom_recursion(node, cluster_tree, node_scores, selected_clusters):
485+
def eom_recursion(node, cluster_tree, node_scores, node_sizes, selected_clusters, max_cluster_size):
483486
current_score = node_scores[node]
487+
current_size = node_sizes[node]
484488

485489
children = cluster_tree.child[cluster_tree.parent == node]
486490
child_score_total = 0.0
487491

488492
for child_node in children:
489-
child_score_total += eom_recursion(child_node, cluster_tree, node_scores, selected_clusters)
493+
child_score_total += eom_recursion(child_node, cluster_tree, node_scores, node_sizes, selected_clusters, max_cluster_size)
490494

491-
if child_score_total > current_score:
495+
if child_score_total > current_score or current_size > max_cluster_size:
492496
return child_score_total
493497
else:
494498
selected_clusters[node] = True
@@ -497,8 +501,13 @@ def eom_recursion(node, cluster_tree, node_scores, selected_clusters):
497501

498502

499503
@numba.njit()
500-
def extract_eom_clusters(condensed_tree, cluster_tree, allow_single_cluster=False):
504+
def extract_eom_clusters(condensed_tree, cluster_tree, max_cluster_size=np.inf, allow_single_cluster=False):
501505
node_scores = score_condensed_tree_nodes(condensed_tree)
506+
if len(cluster_tree.parent) > 0:
507+
node_sizes = {node: size for node, size in zip(cluster_tree.child, cluster_tree.child_size.astype(np.float32))}
508+
node_sizes[cluster_tree.parent.min()] = np.float32(cluster_tree.parent.min() - 1)
509+
else:
510+
node_sizes = {-1: np.float32(0.0)}
502511
selected_clusters = {node: False for node in node_scores}
503512

504513
if len(cluster_tree.parent) == 0:
@@ -507,11 +516,11 @@ def extract_eom_clusters(condensed_tree, cluster_tree, allow_single_cluster=Fals
507516
cluster_tree_root = cluster_tree.parent.min()
508517

509518
if allow_single_cluster:
510-
eom_recursion(cluster_tree_root, cluster_tree, node_scores, selected_clusters)
519+
eom_recursion(cluster_tree_root, cluster_tree, node_scores, node_sizes, selected_clusters, max_cluster_size)
511520
elif len(node_scores) > 1:
512521
root_children = cluster_tree.child[cluster_tree.parent == cluster_tree_root]
513522
for child_node in root_children:
514-
eom_recursion(child_node, cluster_tree, node_scores, selected_clusters)
523+
eom_recursion(child_node, cluster_tree, node_scores, node_sizes, selected_clusters, max_cluster_size)
515524

516525
return np.asarray([node for node, selected in selected_clusters.items() if selected])
517526

fast_hdbscan/hdbscan.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def fast_hdbscan(
141141
min_samples=10,
142142
min_cluster_size=10,
143143
cluster_selection_method="eom",
144+
max_cluster_size=np.inf,
144145
allow_single_cluster=False,
145146
cluster_selection_epsilon=0.0,
146147
sample_weights=None,
@@ -214,7 +215,7 @@ def fast_hdbscan(
214215
raise ValueError(f"Invalid ss_algorithm {ss_algorithm}")
215216
else:
216217
selected_clusters = extract_eom_clusters(
217-
condensed_tree, cluster_tree, allow_single_cluster=allow_single_cluster
218+
condensed_tree, cluster_tree, max_cluster_size=max_cluster_size, allow_single_cluster=allow_single_cluster
218219
)
219220
elif cluster_selection_method == "leaf":
220221
selected_clusters = extract_leaves(
@@ -253,6 +254,7 @@ def __init__(
253254
min_samples=None,
254255
cluster_selection_method="eom",
255256
allow_single_cluster=False,
257+
max_cluster_size=np.inf,
256258
cluster_selection_epsilon=0.0,
257259
semi_supervised=False,
258260
ss_algorithm=None,
@@ -262,6 +264,7 @@ def __init__(
262264
self.min_samples = min_samples
263265
self.cluster_selection_method = cluster_selection_method
264266
self.allow_single_cluster = allow_single_cluster
267+
self.max_cluster_size = max_cluster_size
265268
self.cluster_selection_epsilon = cluster_selection_epsilon
266269
self.semi_supervised = semi_supervised
267270
self.ss_algorithm = ss_algorithm

fast_hdbscan/tests/test_hdbscan.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from sklearn.utils._testing import (
1111
assert_array_equal,
1212
assert_array_almost_equal,
13-
assert_raises,
1413
)
1514
from fast_hdbscan import (
1615
HDBSCAN,
@@ -132,23 +131,30 @@ def test_hdbscan_input_lists():
132131

133132

134133
def test_hdbscan_badargs():
135-
assert_raises(ValueError, fast_hdbscan, "fail")
136-
assert_raises(ValueError, fast_hdbscan, None)
137-
assert_raises(ValueError, fast_hdbscan, X, min_cluster_size="fail")
138-
assert_raises(ValueError, fast_hdbscan, X, min_samples="fail")
139-
assert_raises(ValueError, fast_hdbscan, X, min_samples=-1)
140-
assert_raises(ValueError, fast_hdbscan, X, cluster_selection_epsilon="fail")
141-
assert_raises(ValueError, fast_hdbscan, X, cluster_selection_epsilon=-1)
142-
assert_raises(ValueError, fast_hdbscan, X, cluster_selection_epsilon=-0.1)
143-
assert_raises(
144-
ValueError, fast_hdbscan, X, cluster_selection_method="fail"
145-
)
134+
with pytest.raises(ValueError):
135+
fast_hdbscan("fail")
136+
with pytest.raises(ValueError):
137+
fast_hdbscan(None)
138+
with pytest.raises(ValueError):
139+
fast_hdbscan(X, min_cluster_size="fail")
140+
with pytest.raises(ValueError):
141+
fast_hdbscan(X, min_samples="fail")
142+
with pytest.raises(ValueError):
143+
fast_hdbscan(X, min_samples=-1)
144+
with pytest.raises(ValueError):
145+
fast_hdbscan(X, cluster_selection_epsilon="fail")
146+
with pytest.raises(ValueError):
147+
fast_hdbscan(X, cluster_selection_epsilon=-1)
148+
with pytest.raises(ValueError):
149+
fast_hdbscan(X, cluster_selection_epsilon=-0.1)
150+
with pytest.raises(ValueError):
151+
fast_hdbscan(X, cluster_selection_method="fail")
146152

147153

148154
def test_fhdbscan_allow_single_cluster_with_epsilon():
149155
np.random.seed(0)
150156
no_structure = np.random.rand(150, 2)
151-
# without epsilon we should see 68 noise points and 9 labels
157+
# without epsilon we should see 68 noise points and 8 labels
152158
c = HDBSCAN(
153159
min_cluster_size=5,
154160
cluster_selection_epsilon=0.0,
@@ -167,6 +173,13 @@ def test_fhdbscan_allow_single_cluster_with_epsilon():
167173
assert len(unique_labels) == 2
168174
assert counts[unique_labels == -1] == 2
169175

176+
def test_fhdbscan_max_cluster_size():
177+
model = HDBSCAN(max_cluster_size=30).fit(X)
178+
assert len(set(model.labels_)) >= 3
179+
for label in set(model.labels_):
180+
if label != -1:
181+
assert np.sum(model.labels_ == label) <= 30
182+
170183

171184
# Disable for now -- need to refactor to meet newer standards
172185
@pytest.mark.skip(reason="need to refactor to meet newer standards")

0 commit comments

Comments (0)