Skip to content

Commit d52fb8b

Browse files
committed
Update cluster selection in condensed tree so we can sync leaf selection with soft clustering.
1 parent a565763 commit d52fb8b

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

hdbscan/plots.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ def _get_leaves(condensed_tree):
4949
return _recurse_leaf_dfs(cluster_tree, root)
5050

5151
class CondensedTree(object):
52-
def __init__(self, condensed_tree_array):
52+
def __init__(self, condensed_tree_array, cluster_selection_method='eom'):
5353
self._raw_tree = condensed_tree_array
54+
self.cluster_selection_method = cluster_selection_method
5455

5556
def get_plot_data(self, leaf_separation=1, log_size=False):
5657
"""Generates data for use in plotting the 'icicle plot' or dendrogram
@@ -184,25 +185,28 @@ def get_plot_data(self, leaf_separation=1, log_size=False):
184185
}
185186

186187
def _select_clusters(self):
187-
stability = compute_stability(self._raw_tree)
188-
node_list = sorted(stability.keys(), reverse=True)[:-1]
189-
cluster_tree = self._raw_tree[self._raw_tree['child_size'] > 1]
190-
is_cluster = {cluster: True for cluster in node_list}
191-
192-
for node in node_list:
193-
child_selection = (cluster_tree['parent'] == node)
194-
subtree_stability = np.sum([stability[child] for
195-
child in cluster_tree['child'][child_selection]])
196-
197-
if subtree_stability > stability[node]:
198-
is_cluster[node] = False
199-
stability[node] = subtree_stability
200-
else:
201-
for sub_node in _bfs_from_cluster_tree(cluster_tree, node):
202-
if sub_node != node:
203-
is_cluster[sub_node] = False
188+
if self.cluster_selection_method == 'eom':
189+
stability = compute_stability(self._raw_tree)
190+
node_list = sorted(stability.keys(), reverse=True)[:-1]
191+
cluster_tree = self._raw_tree[self._raw_tree['child_size'] > 1]
192+
is_cluster = {cluster: True for cluster in node_list}
193+
194+
for node in node_list:
195+
child_selection = (cluster_tree['parent'] == node)
196+
subtree_stability = np.sum([stability[child] for
197+
child in cluster_tree['child'][child_selection]])
198+
199+
if subtree_stability > stability[node]:
200+
is_cluster[node] = False
201+
stability[node] = subtree_stability
202+
else:
203+
for sub_node in _bfs_from_cluster_tree(cluster_tree, node):
204+
if sub_node != node:
205+
is_cluster[sub_node] = False
204206

205-
return [cluster for cluster in is_cluster if is_cluster[cluster]]
207+
return [cluster for cluster in is_cluster if is_cluster[cluster]]
208+
else:
209+
return _get_leaves(self._raw_tree)
206210

207211
def plot(self, leaf_separation=1, cmap='viridis', select_clusters=False,
208212
label_clusters=False, selection_palette=None,

0 commit comments

Comments
 (0)