Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 89 additions & 34 deletions statannot/statannot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# pip install statannot

import warnings

import matplotlib.pyplot as plt
Expand All @@ -8,6 +10,7 @@
import pandas as pd
import seaborn as sns
from seaborn.utils import remove_na
import pdb

from .utils import raise_expected_got, assert_is_in
from .StatResult import StatResult
Expand Down Expand Up @@ -41,6 +44,7 @@ def stat_test(
- `t-test_paired`
- `Wilcoxon`
- `Kruskal`
- `Chi squared`
comparisons_correction: str or None, default None
Method to use for multiple comparisons correction. Currently only the
Bonferroni correction is implemented.
Expand Down Expand Up @@ -137,6 +141,16 @@ def stat_test(
result = StatResult(
'Kruskal-Wallis paired samples', 'Kruskal', 'stat', stat, pval
)

elif test == 'chisquare':
stat, pval = stats.chisquare([box_data1.count(), box_data2.count()],
**stats_params)
test_short_name = 'ChiSquare'
result = StatResult(
'ChiSquare categorical groups', 'ChiSquare', 'stat', stat, pval
)


else:
result = StatResult(None, '', None, None, np.nan)

Expand Down Expand Up @@ -263,8 +277,8 @@ def simple_text(pval, pvalue_format, pvalue_thresholds, test_short_name=None):

return text + pval_text


def add_stat_annotation(ax, plot='boxplot',
# ='boxplot' removed after the word plot
def add_stat_annotation(ax, plot,
data=None, x=None, y=None, hue=None, units=None, order=None,
hue_order=None, box_pairs=None, width=0.8,
perform_stat_test=True,
Expand All @@ -279,24 +293,30 @@ def add_stat_annotation(ax, plot='boxplot',
color='0.2', linewidth=1.5,
fontsize='medium', verbose=1):
"""
Optionally computes statistical test between pairs of data series, and add statistical annotation on top
of the boxes/bars. The same exact arguments `data`, `x`, `y`, `hue`, `order`, `width`,
`hue_order` (and `units`) as in the seaborn boxplot/barplot function must be passed to this function.

This function works in one of the two following modes:
a) `perform_stat_test` is True: statistical test as given by argument `test` is performed.
b) `perform_stat_test` is False: no statistical test is performed, list of custom p-values `pvalues` are
used for each pair of boxes. The `test_short_name` argument is then used as the name of the
custom statistical test.

:param plot: type of the plot, one of 'boxplot' or 'barplot'.
:param line_height: in axes fraction coordinates
:param text_offset: in points
:param box_pairs: can be of either form: For non-grouped boxplot: `[(cat1, cat2), (cat3, cat4)]`. For boxplot grouped by hue: `[((cat1, hue1), (cat2, hue2)), ((cat3, hue3), (cat4, hue4))]`
:param pvalue_format_string: defaults to `"{.3e}"`
:param pvalue_thresholds: list of lists, or tuples. Default is: For "star" text_format: `[[1e-4, "****"], [1e-3, "***"], [1e-2, "**"], [0.05, "*"], [1, "ns"]]`. For "simple" text_format : `[[1e-5, "1e-5"], [1e-4, "1e-4"], [1e-3, "0.001"], [1e-2, "0.01"]]`
:param pvalues: list or array of p-values for each box pair comparison.
:param comparisons_correction: Method for multiple comparisons correction. `bonferroni` or None.
Optionally computes statistical test between pairs of data series, and add
statistical annotation on top of the boxes/bars. The same exact arguments
`data`, `x`, `y`, `hue`, `order`, `width`, `hue_order` (and `units`) as in
the seaborn boxplot/barplot function must be passed to this function.

This function works in one of the two following modes: a)
`perform_stat_test` is True: statistical test as given by argument `test`
is performed. b) `perform_stat_test` is False: no statistical test is
performed, list of custom p-values `pvalues` are used for each pair of
boxes. The `test_short_name` argument is then used as the name of the
custom statistical test.

:param plot: type of the plot, one of 'boxplot' or 'barplot'. :param
line_height: in axes fraction coordinates :param text_offset: in points
:param box_pairs: can be of either form: For non-grouped boxplot: `[(cat1,
cat2), (cat3, cat4)]`. For boxplot grouped by hue: `[((cat1, hue1), (cat2,
hue2)), ((cat3, hue3), (cat4, hue4))]` :param pvalue_format_string:
defaults to `"{.3e}"` :param pvalue_thresholds: list of lists, or tuples.
Default is: For "star" text_format: `[[1e-4, "****"], [1e-3, "***"], [1e-2,
"**"], [0.05, "*"], [1, "ns"]]`. For "simple" text_format : `[[1e-5,
"1e-5"], [1e-4, "1e-4"], [1e-3, "0.001"], [1e-2, "0.01"]]` :param
pvalues: list or array of p-values for each box pair comparison. :param
comparisons_correction: Method for multiple comparisons correction.
`bonferroni` or None.
"""

def find_x_position_box(box_plotter, boxName):
Expand Down Expand Up @@ -348,8 +368,10 @@ def get_box_data(box_plotter, boxName):

if pvalue_thresholds is DEFAULT:
if text_format == "star":
pvalue_thresholds = [[1e-4, "****"], [1e-3, "***"],
[1e-2, "**"], [0.05, "*"], [1, "ns"]]
pvalue_thresholds = [[0.0001, r"${****}$"], [0.001, r"${***}$"],
[0.01, r"${**}$"], [0.05, r"$*$"], [1, "ns"]]
# [[0.0001, r"$\ast\ast\ast\ast$"], [0.001, r"$\ast\ast\ast$"],
# [0.01, r"${\ast\ast}$"], [0.05, r"$\ast*$"], [1, "ns"]]
else:
pvalue_thresholds = [[1e-5, "1e-5"], [1e-4, "1e-4"],
[1e-3, "0.001"], [1e-2, "0.01"]]
Expand All @@ -365,7 +387,7 @@ def get_box_data(box_plotter, boxName):
"or `test_short_name` must be `None`.")
valid_list = ['t-test_ind', 't-test_welch', 't-test_paired',
'Mann-Whitney', 'Mann-Whitney-gt', 'Mann-Whitney-ls',
'Levene', 'Wilcoxon', 'Kruskal']
'Levene', 'Wilcoxon', 'Kruskal', 'chisquare']
if test not in valid_list:
raise ValueError("test value should be one of the following: {}."
.format(', '.join(valid_list)))
Expand Down Expand Up @@ -433,31 +455,51 @@ def get_box_data(box_plotter, boxName):
box_plotter = sns.categorical._BoxPlotter(
x, y, hue, data, order, hue_order, orient=None, width=width, color=None,
palette=None, saturation=.75, dodge=True, fliersize=5, linewidth=None)

elif plot == 'barplot':
# Create the same plotter object as seaborn's barplot
box_plotter = sns.categorical._BarPlotter(
x, y, hue, data, order, hue_order,
estimator=np.mean, ci=95, n_boot=1000, units=None,
estimator=np.mean, ci=95, n_boot=1000, units=None, seed=None,
orient=None, color=None, palette=None, saturation=.75,
errcolor=".26", errwidth=None, capsize=None, dodge=True)

elif plot == 'countplot':
# Create the same plotter object as seaborn's countplot
box_plotter = sns.categorical._CountPlotter(
x, y, hue, data, order, hue_order,
estimator=np.mean, ci=95, n_boot=1000, units=None, seed=None,
orient=None, color=None, palette=None, saturation=.75,
errcolor=".26", errwidth=None, capsize=None, dodge=True)

# Build the list of box data structures with the x and ymax positions
group_names = box_plotter.group_names
hue_names = box_plotter.hue_names

if box_plotter.plot_hues is None:
box_names = group_names
labels = box_names
else:
box_names = [(group_name, hue_name) for group_name in group_names for hue_name in hue_names]
labels = ['{}_{}'.format(group_name, hue_name) for (group_name, hue_name) in box_names]

box_structs = [{'box':box_names[i],
if test == 'chisquare':
box_structs = [{'box':box_names[i],
'label':labels[i],
'x':find_x_position_box(box_plotter, box_names[i]),
'box_data':get_box_data(box_plotter, box_names[i]),
'ymax':np.amax(get_box_data(box_plotter, box_names[i])) if
'ymax':np.amax(get_box_data(box_plotter, box_names[i]).count()) if
len(get_box_data(box_plotter, box_names[i])) > 0 else np.nan}
for i in range(len(box_names))]
else:
box_structs = [{'box':box_names[i],
'label':labels[i],
'x':find_x_position_box(box_plotter, box_names[i]),
'box_data':get_box_data(box_plotter, box_names[i]),
'ymax':np.amax(get_box_data(box_plotter, box_names[i])) if
len(get_box_data(box_plotter, box_names[i])) > 0 else np.nan}
for i in range(len(box_names))]

# Sort the box data structures by position along the x axis
box_structs = sorted(box_structs, key=lambda x: x['x'])
# Add the index position in the list of boxes along the x axis
Expand All @@ -467,6 +509,7 @@ def get_box_data(box_plotter, boxName):

# Build the list of box data structure pairs
box_struct_pairs = []

for i_box_pair, (box1, box2) in enumerate(box_pairs):
valid = box1 in box_names and box2 in box_names
if not valid:
Expand All @@ -480,11 +523,11 @@ def get_box_data(box_plotter, boxName):
else:
pair = (box_struct2, box_struct1)
box_struct_pairs.append(pair)

# Draw first the annotations with the shortest between-boxes distance, in order to reduce
# overlapping between annotations.
box_struct_pairs = sorted(box_struct_pairs, key=lambda x: abs(x[1]['x'] - x[0]['x']))

# Build array that contains the x and y_max position of the highest annotation or box data at
# a given x position, and also keeps track of the number of stacked annotations.
# This array will be updated when a new annotation is drawn.
Expand All @@ -499,7 +542,7 @@ def get_box_data(box_plotter, boxName):
y_stack = []

for box_struct1, box_struct2 in box_struct_pairs:

box1 = box_struct1['box']
box2 = box_struct2['box']
label1 = box_struct1['label']
Expand Down Expand Up @@ -542,6 +585,14 @@ def get_box_data(box_plotter, boxName):
result.box2 = box2
test_result_list.append(result)

# Don't plot lines that are not significantly different to only plot significant bars
# (https://github.com/webermarcolivier/statannot/issues/25)
if result.pval > 0.05:
print(result.box1, 'and' ,result.box2, 'did not show significant differences and the p value = {}'.format(result.pval))
continue
else:
print(result.box1, 'and' ,result.box2, 'did show significant differences and the p value = {}'.format(result.pval))

if verbose >= 1:
print("{} v.s. {}: {}".format(label1, label2, result.formatted_output))

Expand Down Expand Up @@ -624,10 +675,14 @@ def get_box_data(box_plotter, boxName):
# Increment the counter of annotations in the y_stack array
y_stack_arr[2, xi1:xi2 + 1] = y_stack_arr[2, xi1:xi2 + 1] + 1

y_stack_max = max(ymaxs)
if loc == 'inside':
ax.set_ylim((ylim[0], max(1.03*y_stack_max, ylim[1])))
elif loc == 'outside':
ax.set_ylim((ylim[0], ylim[1]))
# Check to see if there are actual significant differences
if len(ymaxs) == 0:
pass
else:
y_stack_max = max(ymaxs)
if loc == 'inside':
ax.set_ylim((ylim[0], max(1.03*y_stack_max, ylim[1])))
elif loc == 'outside':
ax.set_ylim((ylim[0], ylim[1]))

return ax, test_result_list