22
33import os
44import shutil
5+ import logging
56
67import numpy as np
78import xarray as xr
89import yaml
910
10- # import batchglm.pkg_constants
11+ from .base import load_benchmark_dataset , get_benchmark_samples
12+
1113from batchglm .api .models .nb_glm import Simulator , Estimator
1214
1315import batchglm .utils .stats as stat_utils
1416
17+ logger = logging .getLogger (__name__ )
18+
1519
1620def init_benchmark (
1721 root_dir : str ,
@@ -99,13 +103,6 @@ def prepare_benchmark_sample(
99103 return sample_config
100104
101105
def get_benchmark_samples(root_dir: str, config_file="config.yml"):
    """
    Return the names of all benchmark samples listed in the benchmark config.

    :param root_dir: directory that contains the benchmark configuration file
    :param config_file: name of the YAML config file, relative to `root_dir`
    :return: list with the keys of the config's "benchmark_samples" mapping;
        empty if that entry is present but has no content
    :raises KeyError: if the config has no "benchmark_samples" entry
    """
    config_path = os.path.join(root_dir, config_file)
    with open(config_path, mode="r") as f:
        # safe_load instead of yaml.load(f): calling load() without an explicit
        # Loader fails on PyYAML >= 6 and constructs arbitrary objects on older
        # versions. NOTE(review): assumes the config holds only plain YAML data
        # (mappings, lists, scalars) - confirm against the code that writes it.
        config = yaml.safe_load(f)

    # An empty "benchmark_samples:" entry is parsed as None; treat it as no samples.
    samples = config["benchmark_samples"] or {}
    return list(samples)
108-
109106def run_benchmark (root_dir : str , sample : str , config_file = "config.yml" ):
110107 config_file = os .path .join (root_dir , config_file )
111108 with open (config_file , mode = "r" ) as f :
@@ -122,58 +119,31 @@ def run_benchmark(root_dir: str, sample: str, config_file="config.yml"):
122119 init_args = sample_config ["init_args" ]
123120 init_args ["working_dir" ] = working_dir
124121
125- print ("loading data..." , end = "" , flush = True )
122+ logger . info ("loading data..." , end = "" , flush = True )
126123 sim = Simulator ()
127124 sim .load (sim_data_file )
128- print ("\t [OK]" )
125+ logger . info ("\t [OK]" )
129126
130- print ("starting estimation of benchmark sample '%s'..." % sample )
127+ logger . info ("starting estimation of benchmark sample '%s'..." % sample )
131128 estimator = Estimator (sim .input_data , batch_size = batch_size )
132129 estimator .initialize (** init_args )
133130 estimator .train (learning_rate = learning_rate )
134- print ("estimation of benchmark sample '%s' ready" % sample )
135-
136-
def load_benchmark_dataset(root_dir: str, config_file="config.yml") -> Tuple[Simulator, xr.Dataset]:
    """
    Load the simulated data and the stored estimation output of every benchmark sample.

    :param root_dir: benchmark root directory; all paths from the config are
        resolved relative to it
    :param config_file: name of the YAML config file inside `root_dir`
    :return: tuple `(sim, benchmark_data)` where `sim` is the Simulator loaded
        from the config's "sim_data" file and `benchmark_data` is the combination
        of all per-sample datasets along the "benchmark" dimension
    """
    config_path = os.path.join(root_dir, config_file)
    with open(config_path, mode="r") as f:
        # safe_load instead of yaml.load(f): calling load() without an explicit
        # Loader fails on PyYAML >= 6 and constructs arbitrary objects on older
        # versions. NOTE(review): assumes the config holds only plain YAML data
        # (mappings, lists, scalars) - confirm against the code that writes it.
        config = yaml.safe_load(f)

    sim_data_file = os.path.join(root_dir, config["sim_data"])
    sim = Simulator()
    sim.load(sim_data_file)

    benchmark_samples = config["benchmark_samples"]
    benchmark_data = []
    for smpl, cfg in benchmark_samples.items():
        # All "estimation-*.h5" files in the sample's working directory are
        # opened as one dataset, concatenated along the "step" dimension.
        data = xr.open_mfdataset(
            os.path.join(root_dir, cfg["working_dir"], "estimation-*.h5"),
            engine="netcdf4",
            concat_dim="step",
            autoclose=True,
            parallel=True,
        )
        data = data.sortby("global_step")
        # Tag the dataset with its sample name; it becomes the coordinate
        # along which the samples are combined below.
        data.coords["benchmark"] = smpl
        benchmark_data.append(data)
    benchmark_data = xr.auto_combine(benchmark_data, concat_dim="benchmark", coords="all")

    return sim, benchmark_data
131+ logger .info ("estimation of benchmark sample '%s' ready" % sample )
162132
163133
164134def plot_benchmark (root_dir : str , config_file = "config.yml" ):
165- print ("loading config..." , end = "" , flush = True )
135+ logger . info ("loading config..." , end = "" , flush = True )
166136 config_file = os .path .join (root_dir , config_file )
167137 with open (config_file , mode = "r" ) as f :
168138 config = yaml .load (f )
169- print ("\t [OK]" )
139+ logger . info ("\t [OK]" )
170140
171141 plot_dir = os .path .join (root_dir , config ["plot_dir" ])
172142
173- print ("loading data..." , end = "" , flush = True )
143+ logger . info ("loading data..." , end = "" , flush = True )
174144 sim , benchmark_data = load_benchmark_dataset (root_dir )
175145 benchmark_data .coords ["time_elapsed" ] = benchmark_data .time_elapsed .cumsum ("step" )
176- print ("\t [OK]" )
146+ logger . info ("\t [OK]" )
177147
178148 import plotnine as pn
179149 import matplotlib .pyplot as plt
@@ -204,7 +174,7 @@ def plot_stat(val, val_name, name_prefix, scale_y_log10=False):
204174 plot = plot + pn .scale_y_log10 ()
205175 plot .save (os .path .join (plot_dir , name_prefix + ".step.svg" ), format = "svg" )
206176
207- print ("plotting..." )
177+ logger . info ("plotting..." )
208178 val : xr .DataArray = stat_utils .rmsd (
209179 np .exp (xr .DataArray (sim .params ["a" ][0 ], dims = ("features" ,))),
210180 np .exp (benchmark_data .a .isel (design_loc_params = 0 )), axis = [0 ])
@@ -218,7 +188,7 @@ def plot_stat(val, val_name, name_prefix, scale_y_log10=False):
218188 val : xr .DataArray = benchmark_data .loss
219189 plot_stat (val , "loss" , "loss" )
220190
221- print ("ready" )
191+ logger . info ("ready" )
222192
223193
224194def clean (root_dir : str ):
@@ -230,7 +200,7 @@ def clean(root_dir: str):
230200 elif os .path .isdir (file_path ):
231201 shutil .rmtree (file_path )
232202 except Exception as e :
233- print (e )
203+ logger . info (e )
234204
235205
236206def main ():
@@ -293,7 +263,7 @@ def main():
293263 elif action == "print_samples" :
294264 benchmark_samples = get_benchmark_samples (root_dir )
295265 for smpl in benchmark_samples :
296- print (smpl )
266+ logger . info (smpl )
297267 elif action == "plot" :
298268 plot_benchmark (root_dir )
299269 elif action == "clean" :
0 commit comments