From b70e27221d48897464311171076ddfa5eb93b180 Mon Sep 17 00:00:00 2001 From: Christopher Creveling Date: Thu, 30 Sep 2021 10:20:19 -0600 Subject: [PATCH] Change to statannot file to allow for asterisks to be the same as in LaTeX and the feature change to allow chi-squared tests on categorical data --- statannot/statannot.py | 123 +++++++++++++++++++++++++++++------------ 1 file changed, 89 insertions(+), 34 deletions(-) diff --git a/statannot/statannot.py b/statannot/statannot.py index 05e3009..1c9dfab 100644 --- a/statannot/statannot.py +++ b/statannot/statannot.py @@ -1,3 +1,5 @@ +# pip install statannot + import warnings import matplotlib.pyplot as plt @@ -8,6 +10,7 @@ import pandas as pd import seaborn as sns from seaborn.utils import remove_na +import pdb from .utils import raise_expected_got, assert_is_in from .StatResult import StatResult @@ -41,6 +44,7 @@ def stat_test( - `t-test_paired` - `Wilcoxon` - `Kruskal` + - `chisquare` comparisons_correction: str or None, default None Method to use for multiple comparisons correction. Currently only the Bonferroni correction is implemented. 
@@ -137,6 +141,16 @@ def stat_test( result = StatResult( 'Kruskal-Wallis paired samples', 'Kruskal', 'stat', stat, pval ) + + elif test == 'chisquare': + stat, pval = stats.chisquare([box_data1.count(), box_data2.count()], + **stats_params) + test_short_name = 'ChiSquare' + result = StatResult( + 'ChiSquare categorical groups', 'ChiSquare', 'stat', stat, pval + ) + + else: result = StatResult(None, '', None, None, np.nan) @@ -263,8 +277,8 @@ def simple_text(pval, pvalue_format, pvalue_thresholds, test_short_name=None): return text + pval_text - -def add_stat_annotation(ax, plot='boxplot', +# ='boxplot' removed after the word plot +def add_stat_annotation(ax, plot, data=None, x=None, y=None, hue=None, units=None, order=None, hue_order=None, box_pairs=None, width=0.8, perform_stat_test=True, @@ -279,24 +293,30 @@ def add_stat_annotation(ax, plot='boxplot', color='0.2', linewidth=1.5, fontsize='medium', verbose=1): """ - Optionally computes statistical test between pairs of data series, and add statistical annotation on top - of the boxes/bars. The same exact arguments `data`, `x`, `y`, `hue`, `order`, `width`, - `hue_order` (and `units`) as in the seaborn boxplot/barplot function must be passed to this function. - - This function works in one of the two following modes: - a) `perform_stat_test` is True: statistical test as given by argument `test` is performed. - b) `perform_stat_test` is False: no statistical test is performed, list of custom p-values `pvalues` are - used for each pair of boxes. The `test_short_name` argument is then used as the name of the - custom statistical test. - - :param plot: type of the plot, one of 'boxplot' or 'barplot'. - :param line_height: in axes fraction coordinates - :param text_offset: in points - :param box_pairs: can be of either form: For non-grouped boxplot: `[(cat1, cat2), (cat3, cat4)]`. 
For boxplot grouped by hue: `[((cat1, hue1), (cat2, hue2)), ((cat3, hue3), (cat4, hue4))]` - :param pvalue_format_string: defaults to `"{.3e}"` - :param pvalue_thresholds: list of lists, or tuples. Default is: For "star" text_format: `[[1e-4, "****"], [1e-3, "***"], [1e-2, "**"], [0.05, "*"], [1, "ns"]]`. For "simple" text_format : `[[1e-5, "1e-5"], [1e-4, "1e-4"], [1e-3, "0.001"], [1e-2, "0.01"]]` - :param pvalues: list or array of p-values for each box pair comparison. - :param comparisons_correction: Method for multiple comparisons correction. `bonferroni` or None. + Optionally computes statistical test between pairs of data series, and add + statistical annotation on top of the boxes/bars. The same exact arguments + `data`, `x`, `y`, `hue`, `order`, `width`, `hue_order` (and `units`) as in + the seaborn boxplot/barplot function must be passed to this function. + + This function works in one of the two following modes: a) + `perform_stat_test` is True: statistical test as given by argument `test` + is performed. b) `perform_stat_test` is False: no statistical test is + performed, list of custom p-values `pvalues` are used for each pair of + boxes. The `test_short_name` argument is then used as the name of the + custom statistical test. + + :param plot: type of the plot, one of 'boxplot' or 'barplot'. :param + line_height: in axes fraction coordinates :param text_offset: in points + :param box_pairs: can be of either form: For non-grouped boxplot: `[(cat1, + cat2), (cat3, cat4)]`. For boxplot grouped by hue: `[((cat1, hue1), (cat2, + hue2)), ((cat3, hue3), (cat4, hue4))]` :param pvalue_format_string: + defaults to `"{.3e}"` :param pvalue_thresholds: list of lists, or tuples. + Default is: For "star" text_format: `[[1e-4, "****"], [1e-3, "***"], [1e-2, + "**"], [0.05, "*"], [1, "ns"]]`. For "simple" text_format : `[[1e-5, + "1e-5"], [1e-4, "1e-4"], [1e-3, "0.001"], [1e-2, "0.01"]]` :param + pvalues: list or array of p-values for each box pair comparison. 
:param + comparisons_correction: Method for multiple comparisons correction. + `bonferroni` or None. """ def find_x_position_box(box_plotter, boxName): @@ -348,8 +368,10 @@ def get_box_data(box_plotter, boxName): if pvalue_thresholds is DEFAULT: if text_format == "star": - pvalue_thresholds = [[1e-4, "****"], [1e-3, "***"], - [1e-2, "**"], [0.05, "*"], [1, "ns"]] + pvalue_thresholds = [[0.0001, r"${****}$"], [0.001, r"${***}$"], + [0.01, r"${**}$"], [0.05, r"$*$"], [1, "ns"]] + # [[0.0001, r"$\ast\ast\ast\ast$"], [0.001, r"$\ast\ast\ast$"], + # [0.01, r"${\ast\ast}$"], [0.05, r"$\ast*$"], [1, "ns"]] else: pvalue_thresholds = [[1e-5, "1e-5"], [1e-4, "1e-4"], [1e-3, "0.001"], [1e-2, "0.01"]] @@ -365,7 +387,7 @@ def get_box_data(box_plotter, boxName): "or `test_short_name` must be `None`.") valid_list = ['t-test_ind', 't-test_welch', 't-test_paired', 'Mann-Whitney', 'Mann-Whitney-gt', 'Mann-Whitney-ls', - 'Levene', 'Wilcoxon', 'Kruskal'] + 'Levene', 'Wilcoxon', 'Kruskal', 'chisquare'] if test not in valid_list: raise ValueError("test value should be one of the following: {}." 
.format(', '.join(valid_list))) @@ -433,17 +455,27 @@ def get_box_data(box_plotter, boxName): box_plotter = sns.categorical._BoxPlotter( x, y, hue, data, order, hue_order, orient=None, width=width, color=None, palette=None, saturation=.75, dodge=True, fliersize=5, linewidth=None) + elif plot == 'barplot': # Create the same plotter object as seaborn's barplot box_plotter = sns.categorical._BarPlotter( x, y, hue, data, order, hue_order, - estimator=np.mean, ci=95, n_boot=1000, units=None, + estimator=np.mean, ci=95, n_boot=1000, units=None, seed=None, + orient=None, color=None, palette=None, saturation=.75, + errcolor=".26", errwidth=None, capsize=None, dodge=True) + + elif plot == 'countplot': + # Create the same plotter object as seaborn's countplot + box_plotter = sns.categorical._CountPlotter( + x, y, hue, data, order, hue_order, + estimator=np.mean, ci=95, n_boot=1000, units=None, seed=None, orient=None, color=None, palette=None, saturation=.75, errcolor=".26", errwidth=None, capsize=None, dodge=True) # Build the list of box data structures with the x and ymax positions group_names = box_plotter.group_names hue_names = box_plotter.hue_names + if box_plotter.plot_hues is None: box_names = group_names labels = box_names @@ -451,13 +483,23 @@ def get_box_data(box_plotter, boxName): box_names = [(group_name, hue_name) for group_name in group_names for hue_name in hue_names] labels = ['{}_{}'.format(group_name, hue_name) for (group_name, hue_name) in box_names] - box_structs = [{'box':box_names[i], + if test == 'chisquare': + box_structs = [{'box':box_names[i], 'label':labels[i], 'x':find_x_position_box(box_plotter, box_names[i]), 'box_data':get_box_data(box_plotter, box_names[i]), - 'ymax':np.amax(get_box_data(box_plotter, box_names[i])) if + 'ymax':np.amax(get_box_data(box_plotter, box_names[i]).count()) if len(get_box_data(box_plotter, box_names[i])) > 0 else np.nan} for i in range(len(box_names))] + else: + box_structs = [{'box':box_names[i], + 'label':labels[i], 
+ 'x':find_x_position_box(box_plotter, box_names[i]), + 'box_data':get_box_data(box_plotter, box_names[i]), + 'ymax':np.amax(get_box_data(box_plotter, box_names[i])) if + len(get_box_data(box_plotter, box_names[i])) > 0 else np.nan} + for i in range(len(box_names))] + # Sort the box data structures by position along the x axis box_structs = sorted(box_structs, key=lambda x: x['x']) # Add the index position in the list of boxes along the x axis @@ -467,6 +509,7 @@ def get_box_data(box_plotter, boxName): # Build the list of box data structure pairs box_struct_pairs = [] + for i_box_pair, (box1, box2) in enumerate(box_pairs): valid = box1 in box_names and box2 in box_names if not valid: @@ -480,11 +523,11 @@ def get_box_data(box_plotter, boxName): else: pair = (box_struct2, box_struct1) box_struct_pairs.append(pair) - + # Draw first the annotations with the shortest between-boxes distance, in order to reduce # overlapping between annotations. box_struct_pairs = sorted(box_struct_pairs, key=lambda x: abs(x[1]['x'] - x[0]['x'])) - + # Build array that contains the x and y_max position of the highest annotation or box data at # a given x position, and also keeps track of the number of stacked annotations. # This array will be updated when a new annotation is drawn. 
@@ -499,7 +542,7 @@ def get_box_data(box_plotter, boxName): y_stack = [] for box_struct1, box_struct2 in box_struct_pairs: - + box1 = box_struct1['box'] box2 = box_struct2['box'] label1 = box_struct1['label'] @@ -542,6 +585,14 @@ def get_box_data(box_plotter, boxName): result.box2 = box2 test_result_list.append(result) + # Don't plot lines that are not significantly different to only plot significant bars + # (https://github.com/webermarcolivier/statannot/issues/25) + if result.pval > 0.05: + print(result.box1, 'and' ,result.box2, 'did not show significant differences and the p value = {}'.format(result.pval)) + continue + else: + print(result.box1, 'and' ,result.box2, 'did show significant differences and the p value = {}'.format(result.pval)) + if verbose >= 1: print("{} v.s. {}: {}".format(label1, label2, result.formatted_output)) @@ -624,10 +675,14 @@ def get_box_data(box_plotter, boxName): # Increment the counter of annotations in the y_stack array y_stack_arr[2, xi1:xi2 + 1] = y_stack_arr[2, xi1:xi2 + 1] + 1 - y_stack_max = max(ymaxs) - if loc == 'inside': - ax.set_ylim((ylim[0], max(1.03*y_stack_max, ylim[1]))) - elif loc == 'outside': - ax.set_ylim((ylim[0], ylim[1])) + # Check to see if there are actual significant differences + if len(ymaxs) == 0: + pass + else: + y_stack_max = max(ymaxs) + if loc == 'inside': + ax.set_ylim((ylim[0], max(1.03*y_stack_max, ylim[1]))) + elif loc == 'outside': + ax.set_ylim((ylim[0], ylim[1])) return ax, test_result_list