@@ -117,42 +117,42 @@ def from_csv(
117117 method , config_dict ['method' ]
118118 )
119119 )
120+ model : str = config_dict ['model' ] # type: ignore
120121 try :
121122 if config_dict ['method' ] == 'sample' :
122- assert isinstance (config_dict ['num_samples' ], int )
123- assert isinstance (config_dict ['num_warmup' ], int )
124- assert isinstance (config_dict ['thin' ], int )
125- assert isinstance (config_dict ['model' ], str )
126123 save_warmup = config_dict ['save_warmup' ] == 1
127124 chains = len (csvfiles )
125+ num_samples : int = config_dict ['num_samples' ] # type: ignore
126+ num_warmup : int = config_dict ['num_warmup' ] # type: ignore
127+ thin : int = config_dict ['thin' ] # type: ignore
128128 sampler_args = SamplerArgs (
129- iter_sampling = config_dict [ ' num_samples' ] ,
130- iter_warmup = config_dict [ ' num_warmup' ] ,
131- thin = config_dict [ ' thin' ] ,
129+ iter_sampling = num_samples ,
130+ iter_warmup = num_warmup ,
131+ thin = thin ,
132132 save_warmup = save_warmup ,
133133 )
134134 # bugfix 425, check for fixed_params output
135135 try :
136136 check_sampler_csv (
137137 csvfiles [0 ],
138- iter_sampling = config_dict [ ' num_samples' ] ,
139- iter_warmup = config_dict [ ' num_warmup' ] ,
140- thin = config_dict [ ' thin' ] ,
138+ iter_sampling = num_samples ,
139+ iter_warmup = num_warmup ,
140+ thin = thin ,
141141 save_warmup = save_warmup ,
142142 )
143143 except ValueError :
144144 try :
145145 check_sampler_csv (
146146 csvfiles [0 ],
147- iter_sampling = config_dict [ ' num_samples' ] ,
148- iter_warmup = config_dict [ ' num_warmup' ] ,
149- thin = config_dict [ ' thin' ] ,
147+ iter_sampling = num_samples ,
148+ iter_warmup = num_warmup ,
149+ thin = thin ,
150150 save_warmup = save_warmup ,
151151 )
152152 sampler_args = SamplerArgs (
153- iter_sampling = config_dict [ ' num_samples' ] ,
154- iter_warmup = config_dict [ ' num_warmup' ] ,
155- thin = config_dict [ ' thin' ] ,
153+ iter_sampling = num_samples ,
154+ iter_warmup = num_warmup ,
155+ thin = thin ,
156156 save_warmup = save_warmup ,
157157 fixed_param = True ,
158158 )
@@ -162,8 +162,8 @@ def from_csv(
162162 ) from e
163163
164164 cmdstan_args = CmdStanArgs (
165- model_name = config_dict [ ' model' ] ,
166- model_exe = config_dict [ ' model' ] ,
165+ model_name = model ,
166+ model_exe = model ,
167167 chain_ids = [x + 1 for x in range (chains )],
168168 method_args = sampler_args ,
169169 )
@@ -180,19 +180,18 @@ def from_csv(
180180 "Cannot find optimization algorithm"
181181 " in file {}." .format (csvfiles [0 ])
182182 )
183- assert isinstance (config_dict ['algorithm' ], str )
184- assert isinstance (config_dict ['model' ], str )
183+ algorithm : str = config_dict ['algorithm' ] # type: ignore
185184 save_iterations = config_dict ['save_iterations' ] == 1
186185 jacobian = config_dict .get ('jacobian' , 0 ) == 1
187186
188187 optimize_args = OptimizeArgs (
189- algorithm = config_dict [ ' algorithm' ] ,
188+ algorithm = algorithm ,
190189 save_iterations = save_iterations ,
191190 jacobian = jacobian ,
192191 )
193192 cmdstan_args = CmdStanArgs (
194- model_name = config_dict [ ' model' ] ,
195- model_exe = config_dict [ ' model' ] ,
193+ model_name = model ,
194+ model_exe = model ,
196195 chain_ids = None ,
197196 method_args = optimize_args ,
198197 )
@@ -207,28 +206,19 @@ def from_csv(
207206 "Cannot find variational algorithm"
208207 " in file {}." .format (csvfiles [0 ])
209208 )
210- assert isinstance (config_dict ['model' ], str )
211- assert isinstance (config_dict ['algorithm' ], str )
212- assert isinstance (config_dict ['iter' ], int )
213- assert isinstance (config_dict ['grad_samples' ], int )
214- assert isinstance (config_dict ['elbo_samples' ], int )
215- assert isinstance (config_dict ['eta' ], (int , float ))
216- assert isinstance (config_dict ['tol_rel_obj' ], float )
217- assert isinstance (config_dict ['eval_elbo' ], int )
218- assert isinstance (config_dict ['output_samples' ], int )
219209 variational_args = VariationalArgs (
220- algorithm = config_dict ['algorithm' ],
221- iter = config_dict ['iter' ],
222- grad_samples = config_dict ['grad_samples' ],
223- elbo_samples = config_dict ['elbo_samples' ],
224- eta = config_dict ['eta' ],
225- tol_rel_obj = config_dict ['tol_rel_obj' ],
226- eval_elbo = config_dict ['eval_elbo' ],
227- output_samples = config_dict ['output_samples' ],
210+ algorithm = config_dict ['algorithm' ], # type: ignore
211+ iter = config_dict ['iter' ], # type: ignore
212+ grad_samples = config_dict ['grad_samples' ], # type: ignore
213+ elbo_samples = config_dict ['elbo_samples' ], # type: ignore
214+ eta = config_dict ['eta' ], # type: ignore
215+ tol_rel_obj = config_dict ['tol_rel_obj' ], # type: ignore
216+ eval_elbo = config_dict ['eval_elbo' ], # type: ignore
217+ output_samples = config_dict ['output_samples' ], # type: ignore
228218 )
229219 cmdstan_args = CmdStanArgs (
230- model_name = config_dict [ ' model' ] ,
231- model_exe = config_dict [ ' model' ] ,
220+ model_name = model ,
221+ model_exe = model ,
232222 chain_ids = None ,
233223 method_args = variational_args ,
234224 )
@@ -238,18 +228,15 @@ def from_csv(
238228 runset ._set_retcode (i , 0 )
239229 return CmdStanVB (runset )
240230 elif config_dict ['method' ] == 'laplace' :
241- assert isinstance (config_dict ['mode' ], str )
242- assert isinstance (config_dict ['draws' ], int )
243- assert isinstance (config_dict ['model' ], str )
244231 jacobian = config_dict ['jacobian' ] == 1
245232 laplace_args = LaplaceArgs (
246- mode = config_dict ['mode' ],
247- draws = config_dict ['draws' ],
233+ mode = config_dict ['mode' ], # type: ignore
234+ draws = config_dict ['draws' ], # type: ignore
248235 jacobian = jacobian ,
249236 )
250237 cmdstan_args = CmdStanArgs (
251- model_name = config_dict [ ' model' ] ,
252- model_exe = config_dict [ ' model' ] ,
238+ model_name = model ,
239+ model_exe = model ,
253240 chain_ids = None ,
254241 method_args = laplace_args ,
255242 )
@@ -258,21 +245,18 @@ def from_csv(
258245 for i in range (len (runset ._retcodes )):
259246 runset ._set_retcode (i , 0 )
260247 mode : CmdStanMLE = from_csv (
261- config_dict ['mode' ],
248+ config_dict ['mode' ], # type: ignore
262249 method = 'optimize' ,
263250 ) # type: ignore
264251 return CmdStanLaplace (runset , mode = mode )
265252 elif config_dict ['method' ] == 'pathfinder' :
266- assert isinstance (config_dict ['num_draws' ], int )
267- assert isinstance (config_dict ['num_paths' ], int )
268- assert isinstance (config_dict ['model' ], str )
269253 pathfinder_args = PathfinderArgs (
270- num_draws = config_dict ['num_draws' ],
271- num_paths = config_dict ['num_paths' ],
254+ num_draws = config_dict ['num_draws' ], # type: ignore
255+ num_paths = config_dict ['num_paths' ], # type: ignore
272256 )
273257 cmdstan_args = CmdStanArgs (
274- model_name = config_dict [ ' model' ] ,
275- model_exe = config_dict [ ' model' ] ,
258+ model_name = model ,
259+ model_exe = model ,
276260 chain_ids = None ,
277261 method_args = pathfinder_args ,
278262 )
0 commit comments