@@ -49,8 +49,9 @@ def _get_leaves(condensed_tree):
4949 return _recurse_leaf_dfs (cluster_tree , root )
5050
5151class CondensedTree (object ):
52- def __init__ (self , condensed_tree_array ):
52+ def __init__ (self , condensed_tree_array , cluster_selection_method = 'eom' ):
5353 self ._raw_tree = condensed_tree_array
54+ self .cluster_selection_method = cluster_selection_method
5455
5556 def get_plot_data (self , leaf_separation = 1 , log_size = False ):
5657 """Generates data for use in plotting the 'icicle plot' or dendrogram
@@ -184,25 +185,28 @@ def get_plot_data(self, leaf_separation=1, log_size=False):
184185 }
185186
186187 def _select_clusters (self ):
187- stability = compute_stability (self ._raw_tree )
188- node_list = sorted (stability .keys (), reverse = True )[:- 1 ]
189- cluster_tree = self ._raw_tree [self ._raw_tree ['child_size' ] > 1 ]
190- is_cluster = {cluster : True for cluster in node_list }
191-
192- for node in node_list :
193- child_selection = (cluster_tree ['parent' ] == node )
194- subtree_stability = np .sum ([stability [child ] for
195- child in cluster_tree ['child' ][child_selection ]])
196-
197- if subtree_stability > stability [node ]:
198- is_cluster [node ] = False
199- stability [node ] = subtree_stability
200- else :
201- for sub_node in _bfs_from_cluster_tree (cluster_tree , node ):
202- if sub_node != node :
203- is_cluster [sub_node ] = False
188+ if self .cluster_selection_method == 'eom' :
189+ stability = compute_stability (self ._raw_tree )
190+ node_list = sorted (stability .keys (), reverse = True )[:- 1 ]
191+ cluster_tree = self ._raw_tree [self ._raw_tree ['child_size' ] > 1 ]
192+ is_cluster = {cluster : True for cluster in node_list }
193+
194+ for node in node_list :
195+ child_selection = (cluster_tree ['parent' ] == node )
196+ subtree_stability = np .sum ([stability [child ] for
197+ child in cluster_tree ['child' ][child_selection ]])
198+
199+ if subtree_stability > stability [node ]:
200+ is_cluster [node ] = False
201+ stability [node ] = subtree_stability
202+ else :
203+ for sub_node in _bfs_from_cluster_tree (cluster_tree , node ):
204+ if sub_node != node :
205+ is_cluster [sub_node ] = False
204206
205- return [cluster for cluster in is_cluster if is_cluster [cluster ]]
207+ return [cluster for cluster in is_cluster if is_cluster [cluster ]]
208+ else :
209+ return _get_leaves (self ._raw_tree )
206210
207211 def plot (self , leaf_separation = 1 , cmap = 'viridis' , select_clusters = False ,
208212 label_clusters = False , selection_palette = None ,
0 commit comments