44import shutil
55import logging
66
7+ import pandas as pd
78import scipy .stats
89import numpy as np
910import xarray as xr
1011import yaml
1112
12- from .base import init_benchmark , get_benchmark_samples , run_benchmark , load_benchmark_dataset
13+ from .base import init_benchmark , get_benchmark_samples , run_benchmark , load_benchmark_dataset , load_config
1314from .base import Simulator
1415
1516import batchglm .utils .stats as stat_utils
1617
1718logger = logging .getLogger (__name__ )
1819
1920
20- def plot_benchmark (root_dir : str , config_file = "config.yml" ):
def group_by_training(benchmark_configs, keys=("optim_algo", "learning_rate")):
    """Tabulate selected training arguments of every benchmark.

    :param benchmark_configs: mapping of benchmark name to its configuration.
        Each configuration must contain a ``"training_args"`` mapping that
        provides every requested key.
    :param keys: name(s) of the training arguments to extract. May be any
        iterable of names (tuple, list, generator, ...) or a single string,
        which is treated as one key.
    :return: ``pandas.DataFrame`` with one row per benchmark (indexed by
        benchmark name) and one column per requested key.
    :raises KeyError: if a configuration lacks ``"training_args"`` or one of
        the requested keys.
    """
    # Materialise the keys once: they are iterated for every benchmark and
    # reused as column labels, so a one-shot iterator would be exhausted
    # after the first row. A bare string is a single key, not a sequence of
    # one-character keys.
    columns = [keys] if isinstance(keys, str) else list(keys)

    rows = {
        benchmark: [cfg["training_args"][key] for key in columns]
        for benchmark, cfg in benchmark_configs.items()
    }

    return pd.DataFrame.from_dict(rows, orient="index", columns=columns)
28+
29+
30+ def plot_all_benchmarks (root_dir , config_file = "config.yml" ):
2131 logger .info ("loading config..." , end = "" , flush = True )
2232 config_file = os .path .join (root_dir , config_file )
2333 with open (config_file , mode = "r" ) as f :
@@ -28,48 +38,71 @@ def plot_benchmark(root_dir: str, config_file="config.yml"):
2838
2939 logger .info ("loading data..." , end = "" , flush = True )
3040 sim , benchmark_data = load_benchmark_dataset (root_dir )
31- benchmark_data .coords ["time_elapsed" ] = benchmark_data .time_elapsed .cumsum ("step" )
3241 logger .info ("\t [OK]" )
3342
43+ plot_benchmarks (
44+ plot_dir = plot_dir ,
45+ sim = sim ,
46+ benchmark_data = benchmark_data ,
47+ benchmark_names = benchmark_data .coords ["benchmark" ]
48+ )
49+
50+
51+ def plot_benchmarks (plot_dir : str , sim , benchmark_data , benchmark_names ):
52+ benchmark_data = benchmark_data .assign_coords (** {
53+ "time_elapsed" : benchmark_data .time_elapsed .cumsum ("step" ),
54+ })
55+ benchmark_data .coords ["benchmark" ] = xr .DataArray (
56+ dims = ("benchmark" ,),
57+ data = benchmark_names
58+ )
59+
60+ groupby_col = "benchmark"
61+ linewidth = 0.5
62+
3463 import plotnine as pn
3564 import matplotlib .pyplot as plt
65+ plt .rcParams ["legend.loc" ] = "center left"
3666
3767 from dask .diagnostics import ProgressBar
3868
def plot_stat(val, val_name, name_prefix, scale_y_log10=False):
    """Plot a statistic against elapsed time and against global step.

    Nested helper: relies on ``plot_dir``, ``groupby_col``, ``pn`` (plotnine),
    ``ProgressBar`` and ``os`` from the enclosing/module scope.

    Two SVG files are written into ``plot_dir``:
    ``<name_prefix>.time.svg`` (x axis ``time_elapsed``) and
    ``<name_prefix>.step.svg`` (x axis ``global_step``). Each shows one line
    per ``groupby_col`` value, a black vertical line at the x position of the
    overall minimum and a translucent horizontal line at the minimum value.

    :param val: array-like object exposing ``to_dataframe(name)``; the
        resulting frame must contain ``time_elapsed``, ``global_step`` and
        the ``groupby_col`` column after ``reset_index()``.
    :param val_name: column name given to the statistic in the data frame.
    :param name_prefix: file name prefix of the generated plots.
    :param scale_y_log10: if true, use a log10-scaled y axis.
    :return: the flattened data frame that was plotted.
    """
    # NOTE(review): the progress bar is assumed to wrap only the data frame
    # conversion; the original indentation is not recoverable here -- confirm.
    with ProgressBar():
        df = val.to_dataframe(val_name)
        df = df.reset_index()

    # Locate the minimum once, by label. After reset_index() labels and
    # positions coincide, but idxmin() stays correct even if they would not.
    best_idx = df[val_name].idxmin()
    best_value = df[val_name].min()

    for x_col, suffix in (("time_elapsed", ".time.svg"), ("global_step", ".step.svg")):
        plot = (
            pn.ggplot(df)
            + pn.aes(x=x_col, y=val_name, group=groupby_col, color=groupby_col)
            + pn.geom_line()
            + pn.geom_vline(xintercept=df.at[best_idx, x_col], color="black")
            + pn.geom_hline(yintercept=best_value, alpha=0.5)
        )
        if scale_y_log10:
            plot = plot + pn.scale_y_log10()
        plot.save(os.path.join(plot_dir, name_prefix + suffix), format="svg")

    return df
95+
6396 logger .info ("plotting..." )
6497 val : xr .DataArray = stat_utils .rmsd (
6598 np .exp (xr .DataArray (sim .params ["a" ][0 ], dims = ("features" ,))),
6699 np .exp (benchmark_data .a .isel (design_loc_params = 0 )), axis = [0 ])
67- plot_stat (val , "mapd" , "real_mu" )
100+ df = plot_stat (val , "mapd" , "real_mu" )
68101
69102 val : xr .DataArray = stat_utils .rmsd (
70103 np .exp (xr .DataArray (sim .params ["b" ][0 ], dims = ("features" ,))),
71104 np .exp (benchmark_data .b .isel (design_scale_params = 0 )), axis = [0 ])
72- plot_stat (val , "mapd" , "real_r" )
105+ df = plot_stat (val , "mapd" , "real_r" )
73106
74107 val : xr .DataArray = benchmark_data .loss
75108 plot_stat (val , "loss" , "loss" )
@@ -90,41 +123,65 @@ def plot_pval(window_size):
90123 t = t [:, window_size :]
91124 df = df [:, window_size :]
92125
93- pval = t .copy ()
94- pval [:, :] = scipy .stats .t (df ).cdf (t )
95- pval .plot .line (hue = "benchmark" )
96- plt .savefig (os .path .join (plot_dir , "pval_convergence.%dsteps.svg" % window_size ), format = "svg" )
97- # plt.show()
98- plt .close ()
126+ pval = xr .DataArray (
127+ name = "pval" ,
128+ data = scipy .stats .t (df ).cdf (t ),
129+ dims = t .dims ,
130+ coords = t .coords
131+ )
132+
133+ fig , ax = plt .subplots ()
134+ lines = pval .plot .line (hue = groupby_col , linewidth = linewidth , ax = ax )
135+ ax .get_legend ().set_bbox_to_anchor ((1 , 0.5 ))
136+ fig .savefig (os .path .join (plot_dir , "pval_convergence.%dsteps.svg" % window_size ),
137+ format = "svg" , bbox_inches = 'tight' )
138+ # fig.show()
139+ plt .close (fig )
99140
141+ plot_pval (25 )
142+ plot_pval (50 )
100143 plot_pval (100 )
101144 plot_pval (200 )
102145 plot_pval (400 )
103146
104- benchmark_data .full_loss .plot .line (hue = "benchmark" )
105- plt .savefig (os .path .join (plot_dir , "full_loss.svg" ), format = "svg" )
106- plt .close ()
147+ fig , ax = plt .subplots ()
148+ lines = benchmark_data .full_loss .plot .line (hue = groupby_col , linewidth = linewidth , ax = ax )
149+ ax .set_ylabel ('full loss' )
150+ ax .get_legend ().set_bbox_to_anchor ((1 , 0.5 ))
151+ fig .savefig (os .path .join (plot_dir , "full_loss.svg" ), format = "svg" , bbox_inches = 'tight' )
152+ plt .close (fig )
107153
108- benchmark_data .loss .plot .line (hue = "benchmark" )
109- plt .savefig (os .path .join (plot_dir , "batch_loss.svg" ), format = "svg" )
110- plt .close ()
154+ fig , ax = plt .subplots ()
155+ lines = benchmark_data .loss .plot .line (hue = groupby_col , linewidth = linewidth , ax = ax )
156+ ax .set_ylabel ('batch loss' )
157+ ax .get_legend ().set_bbox_to_anchor ((1 , 0.5 ))
158+ fig .savefig (os .path .join (plot_dir , "batch_loss.svg" ), format = "svg" , bbox_inches = 'tight' )
159+ plt .close (fig )
111160
def plot_loss_rolling_mean(window_size):
    """Plot the rolling mean of the batch loss and save it as SVG.

    Nested helper: relies on ``benchmark_data``, ``plot_dir``,
    ``groupby_col``, ``linewidth``, ``plt`` and ``logger`` from the
    enclosing/module scope.

    The plot is written to
    ``<plot_dir>/batch_loss_rolling_mean.<window_size>steps.svg``.

    :param window_size: number of steps in the rolling window.
    """
    logger.info("plotting rolling mean of batch loss with window size: %d", window_size)

    rolling_mean = benchmark_data.loss.rolling(step=window_size).mean()
    out_path = os.path.join(plot_dir, "batch_loss_rolling_mean.%dsteps.svg" % window_size)

    fig, ax = plt.subplots()
    try:
        rolling_mean.plot.line(hue=groupby_col, linewidth=linewidth, ax=ax)
        ax.set_ylabel('rolling mean')

        # Move the legend outside the axes; skip if no legend was created.
        legend = ax.get_legend()
        if legend is not None:
            legend.set_bbox_to_anchor((1, 0.5))

        fig.savefig(out_path, format="svg", bbox_inches='tight')
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
118172
119173 plot_loss_rolling_mean (25 )
120174 plot_loss_rolling_mean (50 )
121175 plot_loss_rolling_mean (100 )
122176 plot_loss_rolling_mean (200 )
123177
178+ fig , ax = plt .subplots ()
124179 with ProgressBar ():
125- benchmark_data .full_gradient .mean (dim = "features" ).plot .line (hue = "benchmark" )
126- plt .savefig (os .path .join (plot_dir , "mean_full_gradient.svg" ), format = "svg" )
127- plt .close ()
180+ lines = benchmark_data .full_gradient .mean (dim = "features" ).plot .line (
181+ hue = groupby_col , linewidth = linewidth , ax = ax )
182+ ax .get_legend ().set_bbox_to_anchor ((1 , 0.5 ))
183+ fig .savefig (os .path .join (plot_dir , "mean_full_gradient.svg" ), format = "svg" , bbox_inches = 'tight' )
184+ plt .close (fig )
128185
129186 logger .info ("ready" )
130187
@@ -214,7 +271,7 @@ def main():
214271 for smpl in benchmark_samples :
215272 logger .info (smpl )
216273 elif action == "plot" :
217- plot_benchmark (root_dir )
274+ plot_all_benchmarks (root_dir )
218275 elif action == "clean" :
219276 clean (root_dir )
220277
0 commit comments