2020 import pandas as pd
2121 from matplotlib .axes import Axes
2222
23+ _STATISTIC_DISPLAY_NAME_MAPPING : dict [str , str ] = {
24+ 'x_mean' : 'X Mean' ,
25+ 'y_mean' : 'Y Mean' ,
26+ 'x_stdev' : 'X SD' ,
27+ 'y_stdev' : 'Y SD' ,
28+ 'x_median' : 'X Med.' ,
29+ 'y_median' : 'Y Med.' ,
30+ 'correlation' : 'Corr.' ,
31+ }
32+
2333
2434@plot_with_custom_style
2535def plot (
@@ -28,6 +38,7 @@ def plot(
2838 y_bounds : Iterable [Number ],
2939 save_to : str | Path ,
3040 decimals : int ,
41+ with_median : bool ,
3142 ** save_kwds : Any , # noqa: ANN401
3243) -> Axes | None :
3344 """
@@ -43,6 +54,8 @@ def plot(
4354 Path to save the plot frame to.
4455 decimals : int
4556 The number of integers to highlight as preserved.
57+ with_median : bool
58+ Whether to include the median.
4659 **save_kwds
4760 Additional keyword arguments that will be passed down to
4861 :meth:`matplotlib.figure.Figure.savefig`.
@@ -64,10 +77,24 @@ def plot(
6477 ax .xaxis .set_major_formatter (tick_formatter )
6578 ax .yaxis .set_major_formatter (tick_formatter )
6679
67- res = get_summary_statistics (data )
80+ res = get_summary_statistics (data , with_median = with_median )
81+
82+ if with_median :
83+ fields = (
84+ 'x_mean' ,
85+ 'x_median' ,
86+ 'x_stdev' ,
87+ 'y_mean' ,
88+ 'y_median' ,
89+ 'y_stdev' ,
90+ 'correlation' ,
91+ )
92+ locs = [0.9 , 0.78 , 0.66 , 0.5 , 0.38 , 0.26 , 0.1 ]
93+ else :
94+ fields = ('x_mean' , 'y_mean' , 'x_stdev' , 'y_stdev' , 'correlation' )
95+ locs = np .linspace (0.8 , 0.2 , num = len (fields ))
6896
69- labels = ('X Mean' , 'Y Mean' , 'X SD' , 'Y SD' , 'Corr.' )
70- locs = np .linspace (0.8 , 0.2 , num = len (labels ))
97+ labels = [_STATISTIC_DISPLAY_NAME_MAPPING [field ] for field in fields ]
7198 max_label_length = max ([len (label ) for label in labels ])
7299 max_stat = int (np .log10 (np .max (np .abs (res )))) + 1
73100 mean_x_digits , mean_y_digits = (
@@ -95,17 +122,23 @@ def plot(
95122 transform = ax .transAxes ,
96123 va = 'center' ,
97124 )
98- for label , loc , stat in zip (labels [:- 1 ], locs , res ):
99- add_stat_text (loc , formatter (label , stat ), alpha = 0.3 )
100- add_stat_text (loc , formatter (label , stat )[:- stat_clip ])
101-
102- correlation_str = corr_formatter (labels [- 1 ], res .correlation )
103- for alpha , text in zip ([0.3 , 1 ], [correlation_str , correlation_str [:- stat_clip ]]):
104- add_stat_text (
105- locs [- 1 ],
106- text ,
107- alpha = alpha ,
108- )
125+ for loc , field in zip (locs , fields ):
126+ label = _STATISTIC_DISPLAY_NAME_MAPPING [field ]
127+ stat = getattr (res , field )
128+
129+ if field == 'correlation' :
130+ correlation_str = corr_formatter (labels [- 1 ], res .correlation )
131+ for alpha , text in zip (
132+ [0.3 , 1 ], [correlation_str , correlation_str [:- stat_clip ]]
133+ ):
134+ add_stat_text (
135+ locs [- 1 ],
136+ text ,
137+ alpha = alpha ,
138+ )
139+ else :
140+ add_stat_text (loc , formatter (label , stat ), alpha = 0.3 )
141+ add_stat_text (loc , formatter (label , stat )[:- stat_clip ])
109142
110143 if not save_to :
111144 return ax
0 commit comments