11"""
22CmdStan arguments
33"""
4+
45import os
56from enum import Enum , auto
67from time import time
7- from typing import Any , Dict , List , Mapping , Optional , Union
8+ from typing import Any , List , Mapping , Optional , Union
89
910import numpy as np
1011from numpy .random import default_rng
1112
12- from cmdstanpy import _TMPDIR
13- from cmdstanpy .utils import (
14- cmdstan_path ,
15- cmdstan_version_before ,
16- create_named_text_file ,
17- get_logger ,
18- read_metric ,
19- write_stan_json ,
20- )
13+ from cmdstanpy .utils import cmdstan_path , cmdstan_version_before , get_logger
2114
2215OptionalPath = Union [str , os .PathLike , None ]
2316
@@ -64,9 +57,8 @@ def __init__(
6457 save_warmup : bool = False ,
6558 thin : Optional [int ] = None ,
6659 max_treedepth : Optional [int ] = None ,
67- metric : Union [
68- str , Dict [str , Any ], List [str ], List [Dict [str , Any ]], None
69- ] = None ,
60+ metric_type : Optional [str ] = None ,
61+ metric_file : Union [str , List [str ], None ] = None ,
7062 step_size : Union [float , List [float ], None ] = None ,
7163 adapt_engaged : bool = True ,
7264 adapt_delta : Optional [float ] = None ,
@@ -82,9 +74,8 @@ def __init__(
8274 self .save_warmup = save_warmup
8375 self .thin = thin
8476 self .max_treedepth = max_treedepth
85- self .metric = metric
86- self .metric_type : Optional [str ] = None
87- self .metric_file : Union [str , List [str ], None ] = None
77+ self .metric_type : Optional [str ] = metric_type
78+ self .metric_file : Union [str , List [str ], None ] = metric_file
8879 self .step_size = step_size
8980 self .adapt_engaged = adapt_engaged
9081 self .adapt_delta = adapt_delta
@@ -176,124 +167,15 @@ def validate(self, chains: Optional[int]) -> None:
176167 'Argument "step_size" must be > 0, '
177168 'chain {}, found {}.' .format (i + 1 , step_size )
178169 )
179- if self .metric is not None :
180- if isinstance (self .metric , str ):
181- if self .metric in ['diag' , 'diag_e' ]:
182- self .metric_type = 'diag_e'
183- elif self .metric in ['dense' , 'dense_e' ]:
184- self .metric_type = 'dense_e'
185- elif self .metric in ['unit' , 'unit_e' ]:
186- self .metric_type = 'unit_e'
187- else :
188- if not os .path .exists (self .metric ):
189- raise ValueError ('no such file {}' .format (self .metric ))
190- dims = read_metric (self .metric )
191- if len (dims ) == 1 :
192- self .metric_type = 'diag_e'
193- else :
194- self .metric_type = 'dense_e'
195- self .metric_file = self .metric
196- elif isinstance (self .metric , dict ):
197- if 'inv_metric' not in self .metric :
198- raise ValueError (
199- 'Entry "inv_metric" not found in metric dict.'
200- )
201- dims = list (np .asarray (self .metric ['inv_metric' ]).shape )
202- if len (dims ) == 1 :
203- self .metric_type = 'diag_e'
204- else :
205- self .metric_type = 'dense_e'
206- dict_file = create_named_text_file (
207- dir = _TMPDIR , prefix = "metric" , suffix = ".json"
208- )
209- write_stan_json (dict_file , self .metric )
210- self .metric_file = dict_file
211- elif isinstance (self .metric , (list , tuple )):
212- if len (self .metric ) != chains :
213- raise ValueError (
214- 'Number of metric files must match number of chains,'
215- ' found {} metric files for {} chains.' .format (
216- len (self .metric ), chains
217- )
218- )
219- if all (isinstance (elem , dict ) for elem in self .metric ):
220- metric_files : List [str ] = []
221- for i , metric in enumerate (self .metric ):
222- metric_dict : Dict [str , Any ] = metric # type: ignore
223- if 'inv_metric' not in metric_dict :
224- raise ValueError (
225- 'Entry "inv_metric" not found in metric dict '
226- 'for chain {}.' .format (i + 1 )
227- )
228- if i == 0 :
229- dims = list (
230- np .asarray (metric_dict ['inv_metric' ]).shape
231- )
232- else :
233- dims2 = list (
234- np .asarray (metric_dict ['inv_metric' ]).shape
235- )
236- if dims != dims2 :
237- raise ValueError (
238- 'Found inconsistent "inv_metric" entry '
239- 'for chain {}: entry has dims '
240- '{}, expected {}.' .format (
241- i + 1 , dims , dims2
242- )
243- )
244- dict_file = create_named_text_file (
245- dir = _TMPDIR , prefix = "metric" , suffix = ".json"
246- )
247- write_stan_json (dict_file , metric_dict )
248- metric_files .append (dict_file )
249- if len (dims ) == 1 :
250- self .metric_type = 'diag_e'
251- else :
252- self .metric_type = 'dense_e'
253- self .metric_file = metric_files
254- elif all (isinstance (elem , str ) for elem in self .metric ):
255- metric_files = []
256- for i , metric in enumerate (self .metric ):
257- assert isinstance (metric , str ) # typecheck
258- if not os .path .exists (metric ):
259- raise ValueError ('no such file {}' .format (metric ))
260- if i == 0 :
261- dims = read_metric (metric )
262- else :
263- dims2 = read_metric (metric )
264- if len (dims ) != len (dims2 ):
265- raise ValueError (
266- 'Metrics files {}, {},'
267- ' inconsistent metrics' .format (
268- self .metric [0 ], metric
269- )
270- )
271- if dims != dims2 :
272- raise ValueError (
273- 'Metrics files {}, {},'
274- ' inconsistent metrics' .format (
275- self .metric [0 ], metric
276- )
277- )
278- metric_files .append (metric )
279- if len (dims ) == 1 :
280- self .metric_type = 'diag_e'
281- else :
282- self .metric_type = 'dense_e'
283- self .metric_file = metric_files
284- else :
285- raise ValueError (
286- 'Argument "metric" must be a list of pathnames or '
287- 'Python dicts, found list of {}.' .format (
288- type (self .metric [0 ])
289- )
290- )
291- else :
170+ if self .metric_type is not None :
171+ if self .metric_type in ['diag' , 'dense' , 'unit' ]:
172+ self .metric_type += '_e'
173+ if self .metric_type not in ['diag_e' , 'dense_e' , 'unit_e' ]:
292174 raise ValueError (
293- 'Invalid metric specified, not a recognized metric type, '
294- 'must be either a metric type name, a filepath, dict, '
295- 'or list of per-chain filepaths or dicts. Found '
296- 'an object of type {}.' . format ( type ( self . metric ) )
175+ 'Argument " metric" must be one of [diag, dense, unit, '
176+ ' diag_e, dense_e, unit_e], found {}.' . format (
177+ self . metric_type
178+ )
297179 )
298180
299181 if self .adapt_delta is not None :
@@ -330,7 +212,8 @@ def validate(self, chains: Optional[int]) -> None:
330212
331213 if self .fixed_param and (
332214 self .max_treedepth is not None
333- or self .metric is not None
215+ or self .metric_type is not None
216+ or self .metric_file is not None
334217 or self .step_size is not None
335218 or not (
336219 self .adapt_delta is None
@@ -369,7 +252,7 @@ def compose(self, idx: int, cmd: List[str]) -> List[str]:
369252 cmd .append (f'stepsize={ self .step_size } ' )
370253 else :
371254 cmd .append (f'stepsize={ self .step_size [idx ]} ' )
372- if self .metric is not None :
255+ if self .metric_type is not None :
373256 cmd .append (f'metric={ self .metric_type } ' )
374257 if self .metric_file is not None :
375258 if not isinstance (self .metric_file , list ):
0 commit comments