@@ -221,7 +221,7 @@ def _select_clusters(self):
221221
222222 return [cluster for cluster in is_cluster if is_cluster [cluster ]]
223223
224- def plot (self , leaf_separation = 1 , cmap = 'Blues ' , select_clusters = False ,
224+ def plot (self , leaf_separation = 1 , cmap = 'viridis ' , select_clusters = False ,
225225 label_clusters = False , selection_palette = None ,
226226 axis = None , colorbar = True , log_size = False ):
227227 """Use matplotlib to plot an 'icicle plot' dendrogram of the condensed tree.
@@ -240,7 +240,7 @@ def plot(self, leaf_separation=1, cmap='Blues', select_clusters=False,
240240
241241 cmap : string or matplotlib colormap, optional
242242 The matplotlib colormap to use to color the cluster bars.
243- (default Blues )
243+ (default viridis )
244244
245245 select_clusters : boolean, optional
246246 Whether to draw ovals highlighting which cluster
@@ -431,7 +431,8 @@ class SingleLinkageTree(object):
431431 def __init__ (self , linkage ):
432432 self ._linkage = linkage
433433
434- def plot (self , axis = None , truncate_mode = None , p = 0 , vary_line_width = True ):
434+ def plot (self , axis = None , truncate_mode = None , p = 0 , vary_line_width = True ,
435+ cmap = 'none' , colorbar = False ):
435436 """Plot a dendrogram of the single linkage tree.
436437
437438 Parameters
@@ -462,6 +463,14 @@ def plot(self, axis=None, truncate_mode=None, p=0, vary_line_width=True):
462463 Draw downward branches of the dendrogram with line thickness that
463464 varies depending on the size of the cluster.
464465
466+ cmap : string or matplotlib colormap, optional
467+ The matplotlib colormap to use to color the cluster bars.
468+ (default 'none')
469+
470+ colorbar : boolean, optional
471+ Whether to draw a matplotlib colorbar displaying the range
472+ of cluster sizes as per the colormap. (default True)
473+
465474 Returns
466475 -------
467476 axis : matplotlib axis
@@ -487,6 +496,12 @@ def plot(self, axis=None, truncate_mode=None, p=0, vary_line_width=True):
487496 else :
488497 linewidths = [(1.0 , 1.0 )] * len (Y )
489498
499+ if cmap != 'none' :
500+ color_array = np .log2 (np .array (linewidths ).flatten ())
501+ sm = plt .cm .ScalarMappable (cmap = cmap ,
502+ norm = plt .Normalize (0 , color_array .max ()))
503+ sm .set_array (color_array )
504+
490505 for x , y , lw in zip (X , Y , linewidths ):
491506 left_x = x [:2 ]
492507 right_x = x [2 :]
@@ -495,13 +510,28 @@ def plot(self, axis=None, truncate_mode=None, p=0, vary_line_width=True):
495510 horizontal_x = x [1 :3 ]
496511 horizontal_y = y [1 :3 ]
497512
498- axis .plot (left_x , left_y , color = 'k' , linewidth = np .log2 (1 + lw [0 ]),
499- solid_joinstyle = 'miter' , solid_capstyle = 'butt' )
500- axis .plot (right_x , right_y , color = 'k' , linewidth = np .log2 (1 + lw [1 ]),
501- solid_joinstyle = 'miter' , solid_capstyle = 'butt' )
513+ if cmap != 'none' :
514+ axis .plot (left_x , left_y , color = sm .to_rgba (np .log2 (lw [0 ])),
515+ linewidth = np .log2 (1 + lw [0 ]),
516+ solid_joinstyle = 'miter' , solid_capstyle = 'butt' )
517+ axis .plot (right_x , right_y , color = sm .to_rgba (np .log2 (lw [1 ])),
518+ linewidth = np .log2 (1 + lw [1 ]),
519+ solid_joinstyle = 'miter' , solid_capstyle = 'butt' )
520+ else :
521+ axis .plot (left_x , left_y , color = 'k' ,
522+ linewidth = np .log2 (1 + lw [0 ]),
523+ solid_joinstyle = 'miter' , solid_capstyle = 'butt' )
524+ axis .plot (right_x , right_y , color = 'k' ,
525+ linewidth = np .log2 (1 + lw [1 ]),
526+ solid_joinstyle = 'miter' , solid_capstyle = 'butt' )
527+
502528 axis .plot (horizontal_x , horizontal_y , color = 'k' , linewidth = 1.0 ,
503529 solid_joinstyle = 'miter' , solid_capstyle = 'butt' )
504530
531+ if colorbar :
532+ cb = plt .colorbar (sm )
533+ cb .ax .set_ylabel ('log(Number of points)' )
534+
505535 axis .set_xticks ([])
506536 for side in ('right' , 'top' , 'bottom' ):
507537 axis .spines [side ].set_visible (False )
0 commit comments