22import pandas as pd
33import scipy .stats as st
44import matplotlib .pyplot as plt
5+ import matplotlib .lines as mlines
56
67from matplotlib import cm
78
@@ -96,12 +97,20 @@ def female_male_scatter(results,
9697 min_max = (- max_value_round ,
9798 max_value_round )
9899
100+ patches = []
99101 for i , log in enumerate (groups ):
100102 log = list (log )
101103 if i == 0 :
102104 al = 0.2
103105 else :
104106 al = 0.8
107+ p = mlines .Line2D ([], [],
108+ color = colors [i ],
109+ marker = 'o' ,
110+ label = group_names [i - 1 ],
111+ ms = 10 ,
112+ ls = '' )
113+ patches .append (p )
105114 ax .scatter (beta_female [log ],
106115 beta_male [log ],
107116 s = pvalue_diff [log ],
@@ -114,10 +123,7 @@ def female_male_scatter(results,
114123 ax .set_ylim (min_max )
115124 ax .axhline (y = 0 , color = 'k' )
116125 ax .axvline (x = 0 , color = 'k' )
117- ax .legend (group_names )
118- leg = ax .get_legend ()
119- for i in range (len (group_names )):
120- leg .legendHandles [i ].set_color (colors [i + 1 ])
126+ ax .legend (handles = patches )
121127
122128def female_male_forest (results ,
123129 colors ,
@@ -162,6 +168,7 @@ def female_male_forest(results,
162168 rotation = 90 )
163169 pos = pos + step
164170
171+ patches = []
165172 for sex in sexes :
166173 if sex == 'female' :
167174 sep = 0
@@ -183,29 +190,15 @@ def female_male_forest(results,
183190 round (interval [1 ][y_row ],3 )],
184191 [c ,c ],
185192 color = color )
186- ax .legend (sexes )
187- leg = ax .get_legend ()
188- leg .legendHandles [0 ].set_color (colors (0 ))
189- leg .legendHandles [1 ].set_color (colors (1 ))
190-
191- def female_male_volcano (results ,
192- ax ):
193- '''
194- Volcano plot with differences in betas between males and females
195- and pvalue diff
196-
197- Parameters
198- ----------
199- results: pd.DataFrame
200- single metabolite results
201- ax: plt.axes
202- ax to use for matplotlib
203- '''
204- diff = results ['Beta_female' ] - results ['Beta_male' ]
205- pval = - np .log10 (results ['pvalue_diff' ])
206-
207- ax .scatter (diff ,
208- pval )
193+ p = mlines .Line2D ([], [],
194+ color = color ,
195+ marker = 'o' ,
196+ label = sex ,
197+ ms = 10 ,
198+ ls = '-' )
199+ patches .append (p )
200+
201+ ax .legend (handles = patches )
209202
210203def score_plot (ax ,
211204 qtpad ,
0 commit comments