22
33import glob
44import os
5- from typing import Any , Dict , List , Optional , Union
5+ from typing import List , Optional , Union
66
77from cmdstanpy .cmdstan_args import (
88 CmdStanArgs ,
1212 SamplerArgs ,
1313 VariationalArgs ,
1414)
15- from cmdstanpy .utils import check_sampler_csv , get_logger , scan_config
15+ from cmdstanpy .utils import check_sampler_csv , get_logger , stancsv
1616
1717from .gq import CmdStanGQ
1818from .laplace import CmdStanLaplace
@@ -103,10 +103,9 @@ def from_csv(
103103 ' includes non-csv file: {}' .format (file )
104104 )
105105
106- config_dict : Dict [str , Any ] = {}
107106 try :
108- with open (csvfiles [0 ], 'r' ) as fd :
109- scan_config ( fd , config_dict , 0 )
107+ comments , * _ = stancsv . parse_comments_header_and_draws (csvfiles [0 ])
108+ config_dict = stancsv . parse_config ( comments )
110109 except (IOError , OSError , PermissionError ) as e :
111110 raise ValueError ('Cannot read CSV file: {}' .format (csvfiles [0 ])) from e
112111 if 'model' not in config_dict or 'method' not in config_dict :
@@ -118,39 +117,43 @@ def from_csv(
118117 method , config_dict ['method' ]
119118 )
120119 )
120+ model : str = config_dict ['model' ] # type: ignore
121121 try :
122122 if config_dict ['method' ] == 'sample' :
123+ save_warmup = config_dict ['save_warmup' ] == 1
123124 chains = len (csvfiles )
125+ num_samples : int = config_dict ['num_samples' ] # type: ignore
126+ num_warmup : int = config_dict ['num_warmup' ] # type: ignore
127+ thin : int = config_dict ['thin' ] # type: ignore
124128 sampler_args = SamplerArgs (
125- iter_sampling = config_dict [ ' num_samples' ] ,
126- iter_warmup = config_dict [ ' num_warmup' ] ,
127- thin = config_dict [ ' thin' ] ,
128- save_warmup = config_dict [ ' save_warmup' ] ,
129+ iter_sampling = num_samples ,
130+ iter_warmup = num_warmup ,
131+ thin = thin ,
132+ save_warmup = save_warmup ,
129133 )
130134 # bugfix 425, check for fixed_params output
131135 try :
132136 check_sampler_csv (
133137 csvfiles [0 ],
134- iter_sampling = config_dict [ ' num_samples' ] ,
135- iter_warmup = config_dict [ ' num_warmup' ] ,
136- thin = config_dict [ ' thin' ] ,
137- save_warmup = config_dict [ ' save_warmup' ] ,
138+ iter_sampling = num_samples ,
139+ iter_warmup = num_warmup ,
140+ thin = thin ,
141+ save_warmup = save_warmup ,
138142 )
139143 except ValueError :
140144 try :
141145 check_sampler_csv (
142146 csvfiles [0 ],
143- is_fixed_param = True ,
144- iter_sampling = config_dict ['num_samples' ],
145- iter_warmup = config_dict ['num_warmup' ],
146- thin = config_dict ['thin' ],
147- save_warmup = config_dict ['save_warmup' ],
147+ iter_sampling = num_samples ,
148+ iter_warmup = num_warmup ,
149+ thin = thin ,
150+ save_warmup = save_warmup ,
148151 )
149152 sampler_args = SamplerArgs (
150- iter_sampling = config_dict [ ' num_samples' ] ,
151- iter_warmup = config_dict [ ' num_warmup' ] ,
152- thin = config_dict [ ' thin' ] ,
153- save_warmup = config_dict [ ' save_warmup' ] ,
153+ iter_sampling = num_samples ,
154+ iter_warmup = num_warmup ,
155+ thin = thin ,
156+ save_warmup = save_warmup ,
154157 fixed_param = True ,
155158 )
156159 except ValueError as e :
@@ -159,8 +162,8 @@ def from_csv(
159162 ) from e
160163
161164 cmdstan_args = CmdStanArgs (
162- model_name = config_dict [ ' model' ] ,
163- model_exe = config_dict [ ' model' ] ,
165+ model_name = model ,
166+ model_exe = model ,
164167 chain_ids = [x + 1 for x in range (chains )],
165168 method_args = sampler_args ,
166169 )
@@ -177,14 +180,18 @@ def from_csv(
177180 "Cannot find optimization algorithm"
178181 " in file {}." .format (csvfiles [0 ])
179182 )
183+ algorithm : str = config_dict ['algorithm' ] # type: ignore
184+ save_iterations = config_dict ['save_iterations' ] == 1
185+ jacobian = config_dict .get ('jacobian' , 0 ) == 1
186+
180187 optimize_args = OptimizeArgs (
181- algorithm = config_dict [ ' algorithm' ] ,
182- save_iterations = config_dict [ ' save_iterations' ] ,
183- jacobian = config_dict . get ( ' jacobian' , 0 ) ,
188+ algorithm = algorithm ,
189+ save_iterations = save_iterations ,
190+ jacobian = jacobian ,
184191 )
185192 cmdstan_args = CmdStanArgs (
186- model_name = config_dict [ ' model' ] ,
187- model_exe = config_dict [ ' model' ] ,
193+ model_name = model ,
194+ model_exe = model ,
188195 chain_ids = None ,
189196 method_args = optimize_args ,
190197 )
@@ -200,18 +207,18 @@ def from_csv(
200207 " in file {}." .format (csvfiles [0 ])
201208 )
202209 variational_args = VariationalArgs (
203- algorithm = config_dict ['algorithm' ],
204- iter = config_dict ['iter' ],
205- grad_samples = config_dict ['grad_samples' ],
206- elbo_samples = config_dict ['elbo_samples' ],
207- eta = config_dict ['eta' ],
208- tol_rel_obj = config_dict ['tol_rel_obj' ],
209- eval_elbo = config_dict ['eval_elbo' ],
210- output_samples = config_dict ['output_samples' ],
210+ algorithm = config_dict ['algorithm' ], # type: ignore
211+ iter = config_dict ['iter' ], # type: ignore
212+ grad_samples = config_dict ['grad_samples' ], # type: ignore
213+ elbo_samples = config_dict ['elbo_samples' ], # type: ignore
214+ eta = config_dict ['eta' ], # type: ignore
215+ tol_rel_obj = config_dict ['tol_rel_obj' ], # type: ignore
216+ eval_elbo = config_dict ['eval_elbo' ], # type: ignore
217+ output_samples = config_dict ['output_samples' ], # type: ignore
211218 )
212219 cmdstan_args = CmdStanArgs (
213- model_name = config_dict [ ' model' ] ,
214- model_exe = config_dict [ ' model' ] ,
220+ model_name = model ,
221+ model_exe = model ,
215222 chain_ids = None ,
216223 method_args = variational_args ,
217224 )
@@ -221,14 +228,15 @@ def from_csv(
221228 runset ._set_retcode (i , 0 )
222229 return CmdStanVB (runset )
223230 elif config_dict ['method' ] == 'laplace' :
231+ jacobian = config_dict ['jacobian' ] == 1
224232 laplace_args = LaplaceArgs (
225- mode = config_dict ['mode' ],
226- draws = config_dict ['draws' ],
227- jacobian = config_dict [ ' jacobian' ] ,
233+ mode = config_dict ['mode' ], # type: ignore
234+ draws = config_dict ['draws' ], # type: ignore
235+ jacobian = jacobian ,
228236 )
229237 cmdstan_args = CmdStanArgs (
230- model_name = config_dict [ ' model' ] ,
231- model_exe = config_dict [ ' model' ] ,
238+ model_name = model ,
239+ model_exe = model ,
232240 chain_ids = None ,
233241 method_args = laplace_args ,
234242 )
@@ -237,18 +245,18 @@ def from_csv(
237245 for i in range (len (runset ._retcodes )):
238246 runset ._set_retcode (i , 0 )
239247 mode : CmdStanMLE = from_csv (
240- config_dict ['mode' ],
248+ config_dict ['mode' ], # type: ignore
241249 method = 'optimize' ,
242250 ) # type: ignore
243251 return CmdStanLaplace (runset , mode = mode )
244252 elif config_dict ['method' ] == 'pathfinder' :
245253 pathfinder_args = PathfinderArgs (
246- num_draws = config_dict ['num_draws' ],
247- num_paths = config_dict ['num_paths' ],
254+ num_draws = config_dict ['num_draws' ], # type: ignore
255+ num_paths = config_dict ['num_paths' ], # type: ignore
248256 )
249257 cmdstan_args = CmdStanArgs (
250- model_name = config_dict [ ' model' ] ,
251- model_exe = config_dict [ ' model' ] ,
258+ model_name = model ,
259+ model_exe = model ,
252260 chain_ids = None ,
253261 method_args = pathfinder_args ,
254262 )
0 commit comments