1010import numpy as np
1111from numpy .random import default_rng
1212
13- from cmdstanpy import _TMPDIR
14- from cmdstanpy .utils import (
15- cmdstan_path ,
16- cmdstan_version_before ,
17- create_named_text_file ,
18- get_logger ,
19- read_metric ,
20- write_stan_json ,
21- )
13+ from cmdstanpy .utils import cmdstan_path , cmdstan_version_before , get_logger
2214
2315OptionalPath = Union [str , os .PathLike , None ]
2416
@@ -65,9 +57,8 @@ def __init__(
6557 save_warmup : bool = False ,
6658 thin : Optional [int ] = None ,
6759 max_treedepth : Optional [int ] = None ,
68- metric : Union [
69- str , dict [str , Any ], list [str ], list [dict [str , Any ]], None
70- ] = None ,
60+ metric_type : Optional [str ] = None ,
61+ metric_file : Union [str , list [str ], None ] = None ,
7162 step_size : Union [float , list [float ], None ] = None ,
7263 adapt_engaged : bool = True ,
7364 adapt_delta : Optional [float ] = None ,
@@ -83,9 +74,8 @@ def __init__(
8374 self .save_warmup = save_warmup
8475 self .thin = thin
8576 self .max_treedepth = max_treedepth
86- self .metric = metric
87- self .metric_type : Optional [str ] = None
88- self .metric_file : Union [str , list [str ], None ] = None
77+ self .metric_type : Optional [str ] = metric_type
78+ self .metric_file : Union [str , list [str ], None ] = metric_file
8979 self .step_size = step_size
9080 self .adapt_engaged = adapt_engaged
9181 self .adapt_delta = adapt_delta
@@ -178,124 +168,15 @@ def validate(self, chains: Optional[int]) -> None:
178168 'Argument "step_size" must be > 0, '
179169 'chain {}, found {}.' .format (i + 1 , step_size )
180170 )
181- if self .metric is not None :
182- if isinstance (self .metric , str ):
183- if self .metric in ['diag' , 'diag_e' ]:
184- self .metric_type = 'diag_e'
185- elif self .metric in ['dense' , 'dense_e' ]:
186- self .metric_type = 'dense_e'
187- elif self .metric in ['unit' , 'unit_e' ]:
188- self .metric_type = 'unit_e'
189- else :
190- if not os .path .exists (self .metric ):
191- raise ValueError ('no such file {}' .format (self .metric ))
192- dims = read_metric (self .metric )
193- if len (dims ) == 1 :
194- self .metric_type = 'diag_e'
195- else :
196- self .metric_type = 'dense_e'
197- self .metric_file = self .metric
198- elif isinstance (self .metric , dict ):
199- if 'inv_metric' not in self .metric :
200- raise ValueError (
201- 'Entry "inv_metric" not found in metric dict.'
202- )
203- dims = list (np .asarray (self .metric ['inv_metric' ]).shape )
204- if len (dims ) == 1 :
205- self .metric_type = 'diag_e'
206- else :
207- self .metric_type = 'dense_e'
208- dict_file = create_named_text_file (
209- dir = _TMPDIR , prefix = "metric" , suffix = ".json"
210- )
211- write_stan_json (dict_file , self .metric )
212- self .metric_file = dict_file
213- elif isinstance (self .metric , (list , tuple )):
214- if len (self .metric ) != chains :
215- raise ValueError (
216- 'Number of metric files must match number of chains,'
217- ' found {} metric files for {} chains.' .format (
218- len (self .metric ), chains
219- )
220- )
221- if all (isinstance (elem , dict ) for elem in self .metric ):
222- metric_files : list [str ] = []
223- for i , metric in enumerate (self .metric ):
224- metric_dict : dict [str , Any ] = metric # type: ignore
225- if 'inv_metric' not in metric_dict :
226- raise ValueError (
227- 'Entry "inv_metric" not found in metric dict '
228- 'for chain {}.' .format (i + 1 )
229- )
230- if i == 0 :
231- dims = list (
232- np .asarray (metric_dict ['inv_metric' ]).shape
233- )
234- else :
235- dims2 = list (
236- np .asarray (metric_dict ['inv_metric' ]).shape
237- )
238- if dims != dims2 :
239- raise ValueError (
240- 'Found inconsistent "inv_metric" entry '
241- 'for chain {}: entry has dims '
242- '{}, expected {}.' .format (
243- i + 1 , dims , dims2
244- )
245- )
246- dict_file = create_named_text_file (
247- dir = _TMPDIR , prefix = "metric" , suffix = ".json"
248- )
249- write_stan_json (dict_file , metric_dict )
250- metric_files .append (dict_file )
251- if len (dims ) == 1 :
252- self .metric_type = 'diag_e'
253- else :
254- self .metric_type = 'dense_e'
255- self .metric_file = metric_files
256- elif all (isinstance (elem , str ) for elem in self .metric ):
257- metric_files = []
258- for i , metric in enumerate (self .metric ):
259- assert isinstance (metric , str ) # typecheck
260- if not os .path .exists (metric ):
261- raise ValueError ('no such file {}' .format (metric ))
262- if i == 0 :
263- dims = read_metric (metric )
264- else :
265- dims2 = read_metric (metric )
266- if len (dims ) != len (dims2 ):
267- raise ValueError (
268- 'Metrics files {}, {},'
269- ' inconsistent metrics' .format (
270- self .metric [0 ], metric
271- )
272- )
273- if dims != dims2 :
274- raise ValueError (
275- 'Metrics files {}, {},'
276- ' inconsistent metrics' .format (
277- self .metric [0 ], metric
278- )
279- )
280- metric_files .append (metric )
281- if len (dims ) == 1 :
282- self .metric_type = 'diag_e'
283- else :
284- self .metric_type = 'dense_e'
285- self .metric_file = metric_files
286- else :
287- raise ValueError (
288- 'Argument "metric" must be a list of pathnames or '
289- 'Python dicts, found list of {}.' .format (
290- type (self .metric [0 ])
291- )
292- )
293- else :
171+ if self .metric_type is not None :
172+ if self .metric_type in ['diag' , 'dense' , 'unit' ]:
173+ self .metric_type += '_e'
174+ if self .metric_type not in ['diag_e' , 'dense_e' , 'unit_e' ]:
294175 raise ValueError (
295- 'Invalid metric specified, not a recognized metric type, '
296- 'must be either a metric type name, a filepath, dict, '
297- 'or list of per-chain filepaths or dicts. Found '
298- 'an object of type {}.' . format ( type ( self . metric ) )
176+ 'Argument " metric" must be one of [diag, dense, unit, '
177+ ' diag_e, dense_e, unit_e], found {}.' . format (
178+ self . metric_type
179+ )
299180 )
300181
301182 if self .adapt_delta is not None :
@@ -332,7 +213,8 @@ def validate(self, chains: Optional[int]) -> None:
332213
333214 if self .fixed_param and (
334215 self .max_treedepth is not None
335- or self .metric is not None
216+ or self .metric_type is not None
217+ or self .metric_file is not None
336218 or self .step_size is not None
337219 or not (
338220 self .adapt_delta is None
@@ -371,7 +253,7 @@ def compose(self, idx: int, cmd: list[str]) -> list[str]:
371253 cmd .append (f'stepsize={ self .step_size } ' )
372254 else :
373255 cmd .append (f'stepsize={ self .step_size [idx ]} ' )
374- if self .metric is not None :
256+ if self .metric_type is not None :
375257 cmd .append (f'metric={ self .metric_type } ' )
376258 if self .metric_file is not None :
377259 if not isinstance (self .metric_file , list ):
0 commit comments