22Container for the result of running the sample (MCMC) method
33"""
44
5+ from __future__ import annotations
6+
57import math
68import os
79from io import StringIO
3133 stancsv ,
3234)
3335
34- from .metadata import InferenceMetadata
36+ from .metadata import InferenceMetadata , MetricInfo
3537from .runset import RunSet
3638
3739
@@ -81,6 +83,7 @@ def __init__(
8183 # info from CSV values, instantiated lazily
8284 self ._draws : np .ndarray = np .array (())
8385 # only valid when not is_fixed_param
86+ self ._metric_type : str | None = None
8487 self ._metric : np .ndarray = np .array (())
8588 self ._step_size : np .ndarray = np .array (())
8689 self ._divergences : np .ndarray = np .zeros (self .runset .chains , dtype = int )
@@ -92,6 +95,8 @@ def __init__(
9295 # info from CSV header and initial and final comment blocks
9396 config = self ._validate_csv_files ()
9497 self ._metadata : InferenceMetadata = InferenceMetadata (config )
98+ self ._chain_metric_info : list [MetricInfo ] = []
99+
95100 if not self ._is_fixed_param :
96101 self ._check_sampler_diagnostics ()
97102
@@ -216,11 +221,13 @@ def metric_type(self) -> str | None:
216221 to CmdStan arg 'metric'.
217222 When sampler algorithm 'fixed_param' is specified, metric_type is None.
218223 """
219- return (
220- self ._metadata .cmdstan_config ['metric' ]
221- if not self ._is_fixed_param
222- else None
223- )
224+ if self ._is_fixed_param :
225+ return None
226+
227+ if self ._metric_type is None :
228+ self ._parse_metric_info ()
229+
230+ return self ._metric_type
224231
225232 @property
226233 def inv_metric (self ) -> np .ndarray | None :
@@ -230,10 +237,15 @@ def inv_metric(self) -> np.ndarray | None:
230237 a ``nchains x nparams x nparams`` array when metric_type is 'dense_e',
231238 or ``None`` when metric_type is 'unit_e' or algorithm is 'fixed_param'.
232239 """
233- if self ._is_fixed_param or self .metric_type == 'unit_e' :
240+ if self ._is_fixed_param :
241+ return None
242+
243+ if self ._metric_type is None :
244+ self ._parse_metric_info ()
245+
246+ if self .metric_type == 'unit_e' :
234247 return None
235248
236- self ._assemble_draws ()
237249 return self ._metric
238250
239251 @property
@@ -242,8 +254,13 @@ def step_size(self) -> np.ndarray | None:
242254 Step size used by sampler for each chain.
243255 When sampler algorithm 'fixed_param' is specified, step size is None.
244256 """
245- self ._assemble_draws ()
246- return self ._step_size if not self ._is_fixed_param else None
257+ if self ._is_fixed_param :
258+ return None
259+
260+ if self ._metric_type is None :
261+ self ._parse_metric_info ()
262+
263+ return self ._step_size
247264
248265 @property
249266 def thin (self ) -> int :
@@ -382,6 +399,27 @@ def _validate_csv_files(self) -> dict[str, Any]:
382399 self ._max_treedepths [i ] = drest ['ct_max_treedepth' ]
383400 return dzero
384401
402+ def _parse_metric_info (self ) -> None :
403+ """Extracts metric type, inv_metric, and step size information from the
404+ parsed metric JSONs."""
405+ self ._chain_metric_info = []
406+ for mf in self .runset .metric_files :
407+ with open (mf ) as f :
408+ self ._chain_metric_info .append (
409+ MetricInfo .model_validate_json (f .read ())
410+ )
411+
412+ metric_types = {cmi .metric_type for cmi in self ._chain_metric_info }
413+ if len (metric_types ) != 1 :
414+ raise ValueError ("Inconsistent metric types found across chains" )
415+ self ._metric_type = self ._chain_metric_info [0 ].metric_type
416+ self ._metric = np .asarray (
417+ [cmi .inv_metric for cmi in self ._chain_metric_info ]
418+ )
419+ self ._step_size = np .asarray (
420+ [cmi .stepsize for cmi in self ._chain_metric_info ]
421+ )
422+
385423 def _check_sampler_diagnostics (self ) -> None :
386424 """
387425 Warn if any iterations ended in divergences or hit maxtreedepth.
@@ -424,13 +462,11 @@ def _assemble_draws(self) -> None:
424462 dtype = np .float64 ,
425463 order = 'F' ,
426464 )
427- self ._step_size = np .empty (self .chains , dtype = np .float64 )
428465
429- mass_matrix_per_chain = []
430466 for chain in range (self .chains ):
431467 try :
432468 (
433- comments ,
469+ _ ,
434470 header ,
435471 draws ,
436472 ) = stancsv .parse_comments_header_and_draws (
@@ -443,20 +479,11 @@ def _assemble_draws(self) -> None:
443479 draws_np = np .empty ((0 , n_cols ))
444480
445481 self ._draws [:, chain , :] = draws_np
446- if not self ._is_fixed_param :
447- (
448- self ._step_size [chain ],
449- mass_matrix ,
450- ) = stancsv .parse_hmc_adaptation_lines (comments )
451- mass_matrix_per_chain .append (mass_matrix )
452482 except Exception as exc :
453483 raise ValueError (
454484 f"Parsing output from { self .runset .csv_files [chain ]} failed"
455485 ) from exc
456486
457- if all (mm is not None for mm in mass_matrix_per_chain ):
458- self ._metric = np .array (mass_matrix_per_chain )
459-
460487 assert self ._draws is not None
461488
462489 def summary (
@@ -652,7 +679,7 @@ def draws_pd(
652679
653680 def draws_xr (
654681 self , vars : str | list [str ] | None = None , inc_warmup : bool = False
655- ) -> " xr.Dataset" :
682+ ) -> xr .Dataset :
656683 """
657684 Returns the sampler draws as a xarray Dataset.
658685
0 commit comments