@@ -1002,43 +1002,24 @@ def __init__(self, ax, labels, actives=None):
10021002 if actives is None :
10031003 actives = [False ] * len (labels )
10041004
1005- if len (labels ) > 1 :
1006- dy = 1. / (len (labels ) + 1 )
1007- ys = np .linspace (1 - dy , dy , len (labels ))
1008- else :
1009- dy = 0.25
1010- ys = [0.5 ]
1011-
1012- axcolor = ax .get_facecolor ()
1013-
1014- self .labels = []
1015- self .lines = []
1016- self .rectangles = []
1017-
1018- lineparams = {'color' : 'k' , 'linewidth' : 1.25 ,
1019- 'transform' : ax .transAxes , 'solid_capstyle' : 'butt' }
1020- for y , label , active in zip (ys , labels , actives ):
1021- t = ax .text (0.25 , y , label , transform = ax .transAxes ,
1022- horizontalalignment = 'left' ,
1023- verticalalignment = 'center' )
1024-
1025- w , h = dy / 2 , dy / 2
1026- x , y = 0.05 , y - h / 2
1027-
1028- p = Rectangle (xy = (x , y ), width = w , height = h , edgecolor = 'black' ,
1029- facecolor = axcolor , transform = ax .transAxes )
1005+ ys = np .linspace (1 , 0 , len (labels )+ 2 )[1 :- 1 ]
1006+ text_size = mpl .rcParams ["font.size" ] / 2
10301007
1031- l1 = Line2D ([x , x + w ], [y + h , y ], ** lineparams )
1032- l2 = Line2D ([x , x + w ], [y , y + h ], ** lineparams )
1008+ self .labels = [
1009+ ax .text (0.25 , y , label , transform = ax .transAxes ,
1010+ horizontalalignment = "left" , verticalalignment = "center" )
1011+ for y , label in zip (ys , labels )]
10331012
1034- l1 .set_visible (active )
1035- l2 .set_visible (active )
1036- self .labels .append (t )
1037- self .rectangles .append (p )
1038- self .lines .append ((l1 , l2 ))
1039- ax .add_patch (p )
1040- ax .add_line (l1 )
1041- ax .add_line (l2 )
1013+ self ._squares = ax .scatter (
1014+ [0.15 ] * len (ys ), ys , marker = 's' , c = "none" , linewidth = 1 ,
1015+ transform = ax .transAxes , edgecolor = "k"
1016+ )
1017+ mask = [not x for x in actives ]
1018+ self ._crosses = ax .scatter (
1019+ [0.15 ] * len (ys ), ys , marker = 'x' , linewidth = 1 ,
1020+ c = ["k" if actives [i ] else "none" for i in range (len (ys ))],
1021+ transform = ax .transAxes
1022+ )
10421023
10431024 self .connect_event ('button_press_event' , self ._clicked )
10441025
@@ -1047,11 +1028,29 @@ def __init__(self, ax, labels, actives=None):
10471028 def _clicked (self , event ):
10481029 if self .ignore (event ) or event .button != 1 or event .inaxes != self .ax :
10491030 return
1050- for i , (p , t ) in enumerate (zip (self .rectangles , self .labels )):
1051- if (t .get_window_extent ().contains (event .x , event .y ) or
1052- p .get_window_extent ().contains (event .x , event .y )):
1053- self .set_active (i )
1054- break
1031+ pclicked = self .ax .transAxes .inverted ().transform ((event .x , event .y ))
1032+ _ , square_inds = self ._squares .contains (event )
1033+ coords = self ._squares .get_offset_transform ().transform (
1034+ self ._squares .get_offsets ()
1035+ )
1036+ distances = {}
1037+ if hasattr (self , "_rectangles" ):
1038+ for i , (p , t ) in enumerate (zip (self ._rectangles , self .labels )):
1039+ if (t .get_window_extent ().contains (event .x , event .y )
1040+ or (
1041+ p .get_x () < event .x < p .get_x () + p .get_width ()
1042+ and p .get_y () < event .y < p .get_y ()
1043+ + p .get_height ()
1044+ )):
1045+ distances [i ] = np .linalg .norm (pclicked - p .get_center ())
1046+ else :
1047+ for i , t in enumerate (self .labels ):
1048+ if (i in square_inds ["ind" ]
1049+ or t .get_window_extent ().contains (event .x , event .y )):
1050+ distances [i ] = np .linalg .norm (pclicked - coords [i ])
1051+ if len (distances ) > 0 :
1052+ closest = min (distances , key = distances .get )
1053+ self .set_active (closest )
10551054
10561055 def set_active (self , index ):
10571056 """
@@ -1072,9 +1071,18 @@ def set_active(self, index):
10721071 if index not in range (len (self .labels )):
10731072 raise ValueError (f'Invalid CheckButton index: { index } ' )
10741073
1075- l1 , l2 = self .lines [index ]
1076- l1 .set_visible (not l1 .get_visible ())
1077- l2 .set_visible (not l2 .get_visible ())
1074+ if colors .same_color (
1075+ self ._crosses .get_facecolor ()[index ], colors .to_rgba ("none" )
1076+ ):
1077+ self ._crosses .get_facecolor ()[index ] = colors .to_rgba ("k" )
1078+ else :
1079+ self ._crosses .get_facecolor ()[index ] = colors .to_rgba ("none" )
1080+
1081+ if hasattr (self , "_rectangles" ):
1082+ for i , p in enumerate (self ._rectangles ):
1083+ p .set_facecolor ("k" if colors .same_color (
1084+ p .get_facecolor (), colors .to_rgba ("none" ))
1085+ else "none" )
10781086
10791087 if self .drawon :
10801088 self .ax .figure .canvas .draw ()
@@ -1086,7 +1094,9 @@ def get_status(self):
10861094 """
10871095 Return a tuple of the status (True/False) of all of the check buttons.
10881096 """
1089- return [l1 .get_visible () for (l1 , l2 ) in self .lines ]
1097+ return [False if colors .same_color (
1098+ self ._crosses .get_facecolors ()[i ], colors .to_rgba ("none" ))
1099+ else True for i in range (len (self .labels ))]
10901100
10911101 def on_clicked (self , func ):
10921102 """
@@ -1100,6 +1110,24 @@ def disconnect(self, cid):
11001110 """Remove the observer with connection id *cid*."""
11011111 self ._observers .disconnect (cid )
11021112
1113+ @property
1114+ def rectangles (self ):
1115+ if not hasattr (self , "rectangles" ):
1116+ dy = 1. / (len (self .labels ) + 1 )
1117+ w , h = dy / 2 , dy / 2
1118+ rectangles = self ._rectangles = [
1119+ Rectangle (xy = self ._squares .get_offsets ()[i ], width = w , height = h ,
1120+ edgecolor = "black" ,
1121+ facecolor = self ._squares .get_facecolor ()[i ],
1122+ transform = self .ax .transAxes
1123+ )
1124+ for i in range (len (self .labels ))
1125+ ]
1126+ self ._squares .set_visible (False )
1127+ for rectangle in rectangles :
1128+ self .ax .add_patch (rectangle )
1129+ return self ._rectangles
1130+
11031131
11041132class TextBox (AxesWidget ):
11051133 """
0 commit comments