Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
da33edc
Update pathfinder _assemble_draws
amas0 Aug 8, 2025
4b268e6
Remove unnecessary comment
amas0 Aug 8, 2025
ea67bc8
Update laplace _assemble_draws
amas0 Aug 8, 2025
19a230f
Add new config parsing function
amas0 Aug 8, 2025
18e803f
Add new header extraction functions
amas0 Aug 9, 2025
b46aa07
Update pathfinder to new stancsv parsing
amas0 Aug 9, 2025
42edd71
Update laplace csv parsing logic
amas0 Aug 9, 2025
d93b709
Update gq draws parsing to optimized version
amas0 Aug 9, 2025
b32e0db
Change parse_header output to tuple[str, ...]
amas0 Aug 9, 2025
8344ee7
Implement InferenceMetadata.from_csv for common usage
amas0 Aug 9, 2025
0b41f88
Add InferenceMetadata.__getitem__ for accessing config dict
amas0 Aug 9, 2025
f3c4339
Update CmdStanGQ with new csv parsing functions
amas0 Aug 9, 2025
7f9f4e5
Add InferenceMetadata.column_names property
amas0 Aug 9, 2025
3b90dfb
Remove scan_generic_csv
amas0 Aug 9, 2025
54716c4
Update mle to new stancsv parsing
amas0 Aug 9, 2025
5a61d27
Remove scan_optimize_csv
amas0 Aug 9, 2025
5fc4c10
Use InferenceMetadata.column_name property throughout
amas0 Aug 9, 2025
df91ac8
Add helper to extract key = val pairs from stancsv
amas0 Aug 9, 2025
b927c1b
Update VB to new stancsv methods
amas0 Aug 9, 2025
f389d83
Remove scan_variational_csv
amas0 Aug 9, 2025
ff2a3f5
Use stancsv namespace for consistency
amas0 Aug 9, 2025
964b49d
Add extraction of divergences/max treedepth function
amas0 Aug 14, 2025
89a7b98
Add count function for warmup and sampling draws
amas0 Aug 14, 2025
40c4ee8
Add timing line parsing
amas0 Aug 14, 2025
761c7de
Add new metadata parsing function from sample csv
amas0 Aug 15, 2025
e4b9718
Update check_sampler_csv to use new functions
amas0 Aug 15, 2025
a3037d4
Remove old scan_ parsing functions
amas0 Aug 15, 2025
3b5a7fd
Update type assertion for variation eta in config
amas0 Aug 15, 2025
6023434
Add check to raise exception on invalid draws shape
amas0 Aug 15, 2025
2162c2a
Fixup comments/docstring
amas0 Aug 15, 2025
497d15a
Accommodate automatic fixed_param sampling
amas0 Aug 15, 2025
06bde2b
Add adaptation block validation for csv
amas0 Aug 16, 2025
670a658
Fix incorrect eta being parsed
amas0 Aug 16, 2025
ff2cb53
Remove unreacheable exception handling
amas0 Aug 17, 2025
66b5a24
Add column filter tests
amas0 Aug 17, 2025
4af87ab
Remove errant print statement
amas0 Aug 17, 2025
5c792aa
Fix incorrect missing step size check
amas0 Aug 17, 2025
6b1736c
Add tests for new parsing functions
amas0 Aug 17, 2025
a5e48c1
Remove excessive asserts-as-type-validation
amas0 Aug 19, 2025
ca961c0
Remove unnecessary num chains check
amas0 Aug 19, 2025
27159c4
Re-add NoDataError handling
amas0 Aug 19, 2025
a43896f
Refactor parsing to extract header separately
amas0 Aug 20, 2025
8c2579e
Fixup docstrings
amas0 Aug 20, 2025
3a9db21
Add TODO to remove 'is_sneaky_fixed_param' in future
amas0 Aug 21, 2025
47e5810
Re-raise stancsv parsing failures to identify file
amas0 Aug 21, 2025
f8819db
Specify chain index in gq error
amas0 Aug 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 57 additions & 49 deletions cmdstanpy/stanfit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import glob
import os
from typing import Any, Dict, List, Optional, Union
from typing import List, Optional, Union

from cmdstanpy.cmdstan_args import (
CmdStanArgs,
Expand All @@ -12,7 +12,7 @@
SamplerArgs,
VariationalArgs,
)
from cmdstanpy.utils import check_sampler_csv, get_logger, scan_config
from cmdstanpy.utils import check_sampler_csv, get_logger, stancsv

from .gq import CmdStanGQ
from .laplace import CmdStanLaplace
Expand Down Expand Up @@ -103,10 +103,9 @@ def from_csv(
' includes non-csv file: {}'.format(file)
)

config_dict: Dict[str, Any] = {}
try:
with open(csvfiles[0], 'r') as fd:
scan_config(fd, config_dict, 0)
comments, *_ = stancsv.parse_comments_header_and_draws(csvfiles[0])
config_dict = stancsv.parse_config(comments)
except (IOError, OSError, PermissionError) as e:
raise ValueError('Cannot read CSV file: {}'.format(csvfiles[0])) from e
if 'model' not in config_dict or 'method' not in config_dict:
Expand All @@ -118,39 +117,43 @@ def from_csv(
method, config_dict['method']
)
)
model: str = config_dict['model'] # type: ignore
try:
if config_dict['method'] == 'sample':
save_warmup = config_dict['save_warmup'] == 1
chains = len(csvfiles)
num_samples: int = config_dict['num_samples'] # type: ignore
num_warmup: int = config_dict['num_warmup'] # type: ignore
thin: int = config_dict['thin'] # type: ignore
sampler_args = SamplerArgs(
iter_sampling=config_dict['num_samples'],
iter_warmup=config_dict['num_warmup'],
thin=config_dict['thin'],
save_warmup=config_dict['save_warmup'],
iter_sampling=num_samples,
iter_warmup=num_warmup,
thin=thin,
save_warmup=save_warmup,
)
# bugfix 425, check for fixed_params output
try:
check_sampler_csv(
csvfiles[0],
iter_sampling=config_dict['num_samples'],
iter_warmup=config_dict['num_warmup'],
thin=config_dict['thin'],
save_warmup=config_dict['save_warmup'],
iter_sampling=num_samples,
iter_warmup=num_warmup,
thin=thin,
save_warmup=save_warmup,
)
except ValueError:
try:
check_sampler_csv(
csvfiles[0],
is_fixed_param=True,
iter_sampling=config_dict['num_samples'],
iter_warmup=config_dict['num_warmup'],
thin=config_dict['thin'],
save_warmup=config_dict['save_warmup'],
iter_sampling=num_samples,
iter_warmup=num_warmup,
thin=thin,
save_warmup=save_warmup,
)
sampler_args = SamplerArgs(
iter_sampling=config_dict['num_samples'],
iter_warmup=config_dict['num_warmup'],
thin=config_dict['thin'],
save_warmup=config_dict['save_warmup'],
iter_sampling=num_samples,
iter_warmup=num_warmup,
thin=thin,
save_warmup=save_warmup,
fixed_param=True,
)
except ValueError as e:
Expand All @@ -159,8 +162,8 @@ def from_csv(
) from e

cmdstan_args = CmdStanArgs(
model_name=config_dict['model'],
model_exe=config_dict['model'],
model_name=model,
model_exe=model,
chain_ids=[x + 1 for x in range(chains)],
method_args=sampler_args,
)
Expand All @@ -177,14 +180,18 @@ def from_csv(
"Cannot find optimization algorithm"
" in file {}.".format(csvfiles[0])
)
algorithm: str = config_dict['algorithm'] # type: ignore
save_iterations = config_dict['save_iterations'] == 1
jacobian = config_dict.get('jacobian', 0) == 1

optimize_args = OptimizeArgs(
algorithm=config_dict['algorithm'],
save_iterations=config_dict['save_iterations'],
jacobian=config_dict.get('jacobian', 0),
algorithm=algorithm,
save_iterations=save_iterations,
jacobian=jacobian,
)
cmdstan_args = CmdStanArgs(
model_name=config_dict['model'],
model_exe=config_dict['model'],
model_name=model,
model_exe=model,
chain_ids=None,
method_args=optimize_args,
)
Expand All @@ -200,18 +207,18 @@ def from_csv(
" in file {}.".format(csvfiles[0])
)
variational_args = VariationalArgs(
algorithm=config_dict['algorithm'],
iter=config_dict['iter'],
grad_samples=config_dict['grad_samples'],
elbo_samples=config_dict['elbo_samples'],
eta=config_dict['eta'],
tol_rel_obj=config_dict['tol_rel_obj'],
eval_elbo=config_dict['eval_elbo'],
output_samples=config_dict['output_samples'],
algorithm=config_dict['algorithm'], # type: ignore
iter=config_dict['iter'], # type: ignore
grad_samples=config_dict['grad_samples'], # type: ignore
elbo_samples=config_dict['elbo_samples'], # type: ignore
eta=config_dict['eta'], # type: ignore
tol_rel_obj=config_dict['tol_rel_obj'], # type: ignore
eval_elbo=config_dict['eval_elbo'], # type: ignore
output_samples=config_dict['output_samples'], # type: ignore
)
cmdstan_args = CmdStanArgs(
model_name=config_dict['model'],
model_exe=config_dict['model'],
model_name=model,
model_exe=model,
chain_ids=None,
method_args=variational_args,
)
Expand All @@ -221,14 +228,15 @@ def from_csv(
runset._set_retcode(i, 0)
return CmdStanVB(runset)
elif config_dict['method'] == 'laplace':
jacobian = config_dict['jacobian'] == 1
laplace_args = LaplaceArgs(
mode=config_dict['mode'],
draws=config_dict['draws'],
jacobian=config_dict['jacobian'],
mode=config_dict['mode'], # type: ignore
draws=config_dict['draws'], # type: ignore
jacobian=jacobian,
)
cmdstan_args = CmdStanArgs(
model_name=config_dict['model'],
model_exe=config_dict['model'],
model_name=model,
model_exe=model,
chain_ids=None,
method_args=laplace_args,
)
Expand All @@ -237,18 +245,18 @@ def from_csv(
for i in range(len(runset._retcodes)):
runset._set_retcode(i, 0)
mode: CmdStanMLE = from_csv(
config_dict['mode'],
config_dict['mode'], # type: ignore
method='optimize',
) # type: ignore
return CmdStanLaplace(runset, mode=mode)
elif config_dict['method'] == 'pathfinder':
pathfinder_args = PathfinderArgs(
num_draws=config_dict['num_draws'],
num_paths=config_dict['num_paths'],
num_draws=config_dict['num_draws'], # type: ignore
num_paths=config_dict['num_paths'], # type: ignore
)
cmdstan_args = CmdStanArgs(
model_name=config_dict['model'],
model_exe=config_dict['model'],
model_name=model,
model_exe=model,
chain_ids=None,
method_args=pathfinder_args,
)
Expand Down
91 changes: 45 additions & 46 deletions cmdstanpy/stanfit/gq.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@


from cmdstanpy.cmdstan_args import Method
from cmdstanpy.utils import build_xarray_data, flatten_chains, get_logger
from cmdstanpy.utils.stancsv import scan_generic_csv
from cmdstanpy.utils import (
build_xarray_data,
flatten_chains,
get_logger,
stancsv,
)

from .mcmc import CmdStanMCMC
from .metadata import InferenceMetadata
Expand Down Expand Up @@ -65,8 +69,7 @@ def __init__(
self.previous_fit: Fit = previous_fit

self._draws: np.ndarray = np.array(())
config = self._validate_csv_files()
self._metadata = InferenceMetadata(config)
self._metadata = self._validate_csv_files()

def __repr__(self) -> str:
repr = 'CmdStanGQ: model={} chains={}{}'.format(
Expand Down Expand Up @@ -99,48 +102,38 @@ def __getstate__(self) -> dict:
self._assemble_generated_quantities()
return self.__dict__

def _validate_csv_files(self) -> Dict[str, Any]:
def _validate_csv_files(self) -> InferenceMetadata:
"""
Checks that Stan CSV output files for all chains are consistent
and returns dict containing config and column names.
and returns InferenceMetadata object containing config and column names.

Raises exception when inconsistencies detected.
Raises exception if inconsistencies are detected.
"""
dzero = {}
for i in range(self.chains):
if i == 0:
dzero = scan_generic_csv(
path=self.runset.csv_files[i],
)
else:
drest = scan_generic_csv(
path=self.runset.csv_files[i],
)
for key in dzero:
if (
key
not in [
'id',
'fitted_params',
'diagnostic_file',
'metric_file',
'profile_file',
'init',
'seed',
'start_datetime',
]
and dzero[key] != drest[key]
):
raise ValueError(
'CmdStan config mismatch in Stan CSV file {}: '
'arg {} is {}, expected {}'.format(
self.runset.csv_files[i],
key,
dzero[key],
drest[key],
)
excluded_fields = {
'id',
'fitted_params',
'diagnostic_file',
'metric_file',
'profile_file',
'init',
'seed',
'start_datetime',
}
meta0 = InferenceMetadata.from_csv(self.runset.csv_files[0])
for i in range(1, self.chains):
meta = InferenceMetadata.from_csv(self.runset.csv_files[i])
for key in set(meta._cmdstan_config.keys()) - excluded_fields:
if meta0[key] != meta[key]:
raise ValueError(
'CmdStan config mismatch in Stan CSV file {}: '
'arg {} is {}, expected {}'.format(
self.runset.csv_files[i],
key,
meta0[key],
meta[key],
)
return dzero
)
return meta0

@property
def chains(self) -> int:
Expand All @@ -157,7 +150,7 @@ def column_names(self) -> Tuple[str, ...]:
"""
Names of generated quantities of interest.
"""
return self._metadata.cmdstan_config['column_names'] # type: ignore
return self._metadata.column_names

@property
def metadata(self) -> InferenceMetadata:
Expand Down Expand Up @@ -633,11 +626,17 @@ def _assemble_generated_quantities(self) -> None:
order='F',
)
for chain in range(self.chains):
with open(self.runset.csv_files[chain], 'r') as fd:
lines = (line for line in fd if not line.startswith('#'))
gq_sample[:, chain, :] = np.loadtxt(
lines, dtype=np.ndarray, ndmin=2, skiprows=1, delimiter=','
csv_file = self.runset.csv_files[chain]
try:
*_, draws = stancsv.parse_comments_header_and_draws(
self.runset.csv_files[chain]
)
gq_sample[:, chain, :] = stancsv.csv_bytes_list_to_numpy(draws)
except Exception as exc:
raise ValueError(
f"An error occurred when parsing Stan csv {csv_file}"
f" for chain {chain}"
) from exc
self._draws = gq_sample

def _draws_start(self, inc_warmup: bool) -> Tuple[int, int]:
Expand Down
27 changes: 12 additions & 15 deletions cmdstanpy/stanfit/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
XARRAY_INSTALLED = False

from cmdstanpy.cmdstan_args import Method
from cmdstanpy.utils import stancsv
from cmdstanpy.utils.data_munging import build_xarray_data
from cmdstanpy.utils.stancsv import scan_generic_csv

from .metadata import InferenceMetadata
from .mle import CmdStanMLE
Expand All @@ -46,11 +46,8 @@ def __init__(self, runset: RunSet, mode: CmdStanMLE) -> None:
)
self._runset = runset
self._mode = mode

self._draws: np.ndarray = np.array(())

config = scan_generic_csv(runset.csv_files[0])
self._metadata = InferenceMetadata(config)
self._metadata = InferenceMetadata.from_csv(self._runset.csv_files[0])

def create_inits(
self, seed: Optional[int] = None, chains: int = 4
Expand Down Expand Up @@ -89,16 +86,16 @@ def _assemble_draws(self) -> None:
if self._draws.shape != (0,):
return

with open(self._runset.csv_files[0], 'r') as fd:
while (fd.readline()).startswith("#"):
pass
self._draws = np.loadtxt(
fd,
dtype=float,
ndmin=2,
delimiter=',',
comments="#",
csv_file = self._runset.csv_files[0]
try:
*_, draws = stancsv.parse_comments_header_and_draws(
self._runset.csv_files[0]
)
self._draws = stancsv.csv_bytes_list_to_numpy(draws)
except Exception as exc:
raise ValueError(
f"An error occurred when parsing Stan csv {csv_file}"
) from exc

def stan_variable(self, var: str) -> np.ndarray:
"""
Expand Down Expand Up @@ -318,7 +315,7 @@ def column_names(self) -> Tuple[str, ...]:
and quantities of interest. Corresponds to Stan CSV file header row,
with names munged to array notation, e.g. `beta[1]` not `beta.1`.
"""
return self._metadata.cmdstan_config['column_names'] # type: ignore
return self._metadata.column_names

def save_csvfiles(self, dir: Optional[str] = None) -> None:
"""
Expand Down
Loading
Loading