diff --git a/cmdstanpy/stanfit/__init__.py b/cmdstanpy/stanfit/__init__.py index 50764a30..2ed527c3 100644 --- a/cmdstanpy/stanfit/__init__.py +++ b/cmdstanpy/stanfit/__init__.py @@ -2,7 +2,7 @@ import glob import os -from typing import Any, Dict, List, Optional, Union +from typing import List, Optional, Union from cmdstanpy.cmdstan_args import ( CmdStanArgs, @@ -12,7 +12,7 @@ SamplerArgs, VariationalArgs, ) -from cmdstanpy.utils import check_sampler_csv, get_logger, scan_config +from cmdstanpy.utils import check_sampler_csv, get_logger, stancsv from .gq import CmdStanGQ from .laplace import CmdStanLaplace @@ -103,10 +103,9 @@ def from_csv( ' includes non-csv file: {}'.format(file) ) - config_dict: Dict[str, Any] = {} try: - with open(csvfiles[0], 'r') as fd: - scan_config(fd, config_dict, 0) + comments, *_ = stancsv.parse_comments_header_and_draws(csvfiles[0]) + config_dict = stancsv.parse_config(comments) except (IOError, OSError, PermissionError) as e: raise ValueError('Cannot read CSV file: {}'.format(csvfiles[0])) from e if 'model' not in config_dict or 'method' not in config_dict: @@ -118,39 +117,43 @@ def from_csv( method, config_dict['method'] ) ) + model: str = config_dict['model'] # type: ignore try: if config_dict['method'] == 'sample': + save_warmup = config_dict['save_warmup'] == 1 chains = len(csvfiles) + num_samples: int = config_dict['num_samples'] # type: ignore + num_warmup: int = config_dict['num_warmup'] # type: ignore + thin: int = config_dict['thin'] # type: ignore sampler_args = SamplerArgs( - iter_sampling=config_dict['num_samples'], - iter_warmup=config_dict['num_warmup'], - thin=config_dict['thin'], - save_warmup=config_dict['save_warmup'], + iter_sampling=num_samples, + iter_warmup=num_warmup, + thin=thin, + save_warmup=save_warmup, ) # bugfix 425, check for fixed_params output try: check_sampler_csv( csvfiles[0], - iter_sampling=config_dict['num_samples'], - iter_warmup=config_dict['num_warmup'], - thin=config_dict['thin'], - save_warmup=config_dict['save_warmup'], + iter_sampling=num_samples, + iter_warmup=num_warmup, + thin=thin, + save_warmup=save_warmup, ) except ValueError: try: check_sampler_csv( csvfiles[0], - is_fixed_param=True, - iter_sampling=config_dict['num_samples'], - iter_warmup=config_dict['num_warmup'], - thin=config_dict['thin'], - save_warmup=config_dict['save_warmup'], + iter_sampling=num_samples, + iter_warmup=num_warmup, + thin=thin, + save_warmup=save_warmup, ) sampler_args = SamplerArgs( - iter_sampling=config_dict['num_samples'], - iter_warmup=config_dict['num_warmup'], - thin=config_dict['thin'], - save_warmup=config_dict['save_warmup'], + iter_sampling=num_samples, + iter_warmup=num_warmup, + thin=thin, + save_warmup=save_warmup, fixed_param=True, ) except ValueError as e: @@ -159,8 +162,8 @@ def from_csv( ) from e cmdstan_args = CmdStanArgs( - model_name=config_dict['model'], - model_exe=config_dict['model'], + model_name=model, + model_exe=model, chain_ids=[x + 1 for x in range(chains)], method_args=sampler_args, ) @@ -177,14 +180,18 @@ def from_csv( "Cannot find optimization algorithm" " in file {}.".format(csvfiles[0]) ) + algorithm: str = config_dict['algorithm'] # type: ignore + save_iterations = config_dict['save_iterations'] == 1 + jacobian = config_dict.get('jacobian', 0) == 1 + optimize_args = OptimizeArgs( - algorithm=config_dict['algorithm'], - save_iterations=config_dict['save_iterations'], - jacobian=config_dict.get('jacobian', 0), + algorithm=algorithm, + save_iterations=save_iterations, + jacobian=jacobian, ) cmdstan_args = CmdStanArgs( - model_name=config_dict['model'], - model_exe=config_dict['model'], + model_name=model, + model_exe=model, chain_ids=None, method_args=optimize_args, ) @@ -200,18 +207,18 @@ def from_csv( " in file {}.".format(csvfiles[0]) ) variational_args = VariationalArgs( - algorithm=config_dict['algorithm'], - iter=config_dict['iter'], - grad_samples=config_dict['grad_samples'], - elbo_samples=config_dict['elbo_samples'], - eta=config_dict['eta'], - tol_rel_obj=config_dict['tol_rel_obj'], - eval_elbo=config_dict['eval_elbo'], - output_samples=config_dict['output_samples'], + algorithm=config_dict['algorithm'], # type: ignore + iter=config_dict['iter'], # type: ignore + grad_samples=config_dict['grad_samples'], # type: ignore + elbo_samples=config_dict['elbo_samples'], # type: ignore + eta=config_dict['eta'], # type: ignore + tol_rel_obj=config_dict['tol_rel_obj'], # type: ignore + eval_elbo=config_dict['eval_elbo'], # type: ignore + output_samples=config_dict['output_samples'], # type: ignore ) cmdstan_args = CmdStanArgs( - model_name=config_dict['model'], - model_exe=config_dict['model'], + model_name=model, + model_exe=model, chain_ids=None, method_args=variational_args, ) @@ -221,14 +228,15 @@ def from_csv( runset._set_retcode(i, 0) return CmdStanVB(runset) elif config_dict['method'] == 'laplace': + jacobian = config_dict['jacobian'] == 1 laplace_args = LaplaceArgs( - mode=config_dict['mode'], - draws=config_dict['draws'], - jacobian=config_dict['jacobian'], + mode=config_dict['mode'], # type: ignore + draws=config_dict['draws'], # type: ignore + jacobian=jacobian, ) cmdstan_args = CmdStanArgs( - model_name=config_dict['model'], - model_exe=config_dict['model'], + model_name=model, + model_exe=model, chain_ids=None, method_args=laplace_args, ) @@ -237,18 +245,18 @@ def from_csv( for i in range(len(runset._retcodes)): runset._set_retcode(i, 0) mode: CmdStanMLE = from_csv( - config_dict['mode'], + config_dict['mode'], # type: ignore method='optimize', ) # type: ignore return CmdStanLaplace(runset, mode=mode) elif config_dict['method'] == 'pathfinder': pathfinder_args = PathfinderArgs( - num_draws=config_dict['num_draws'], - num_paths=config_dict['num_paths'], + num_draws=config_dict['num_draws'], # type: ignore + num_paths=config_dict['num_paths'], # type: ignore ) cmdstan_args = CmdStanArgs( - model_name=config_dict['model'], - model_exe=config_dict['model'], + model_name=model, + model_exe=model, chain_ids=None, method_args=pathfinder_args, ) diff --git a/cmdstanpy/stanfit/gq.py b/cmdstanpy/stanfit/gq.py index 6c77ec95..40658ced 100644 --- a/cmdstanpy/stanfit/gq.py +++ b/cmdstanpy/stanfit/gq.py @@ -31,8 +31,12 @@ from cmdstanpy.cmdstan_args import Method -from cmdstanpy.utils import build_xarray_data, flatten_chains, get_logger -from cmdstanpy.utils.stancsv import scan_generic_csv +from cmdstanpy.utils import ( + build_xarray_data, + flatten_chains, + get_logger, + stancsv, +) from .mcmc import CmdStanMCMC from .metadata import InferenceMetadata @@ -65,8 +69,7 @@ def __init__( self.previous_fit: Fit = previous_fit self._draws: np.ndarray = np.array(()) - config = self._validate_csv_files() - self._metadata = InferenceMetadata(config) + self._metadata = self._validate_csv_files() def __repr__(self) -> str: repr = 'CmdStanGQ: model={} chains={}{}'.format( @@ -99,48 +102,38 @@ def __getstate__(self) -> dict: self._assemble_generated_quantities() return self.__dict__ - def _validate_csv_files(self) -> Dict[str, Any]: + def _validate_csv_files(self) -> InferenceMetadata: """ Checks that Stan CSV output files for all chains are consistent - and returns dict containing config and column names. + and returns InferenceMetadata object containing config and column names. - Raises exception when inconsistencies detected. + Raises exception if inconsistencies are detected. """ - dzero = {} - for i in range(self.chains): - if i == 0: - dzero = scan_generic_csv( - path=self.runset.csv_files[i], - ) - else: - drest = scan_generic_csv( - path=self.runset.csv_files[i], - ) - for key in dzero: - if ( - key - not in [ - 'id', - 'fitted_params', - 'diagnostic_file', - 'metric_file', - 'profile_file', - 'init', - 'seed', - 'start_datetime', - ] - and dzero[key] != drest[key] - ): - raise ValueError( - 'CmdStan config mismatch in Stan CSV file {}: ' - 'arg {} is {}, expected {}'.format( - self.runset.csv_files[i], - key, - dzero[key], - drest[key], - ) + excluded_fields = { + 'id', + 'fitted_params', + 'diagnostic_file', + 'metric_file', + 'profile_file', + 'init', + 'seed', + 'start_datetime', + } + meta0 = InferenceMetadata.from_csv(self.runset.csv_files[0]) + for i in range(1, self.chains): + meta = InferenceMetadata.from_csv(self.runset.csv_files[i]) + for key in set(meta._cmdstan_config.keys()) - excluded_fields: + if meta0[key] != meta[key]: + raise ValueError( + 'CmdStan config mismatch in Stan CSV file {}: ' + 'arg {} is {}, expected {}'.format( + self.runset.csv_files[i], + key, + meta0[key], + meta[key], ) - return dzero + ) + return meta0 @property def chains(self) -> int: @@ -157,7 +150,7 @@ def column_names(self) -> Tuple[str, ...]: """ Names of generated quantities of interest. """ - return self._metadata.cmdstan_config['column_names'] # type: ignore + return self._metadata.column_names @property def metadata(self) -> InferenceMetadata: @@ -633,11 +626,17 @@ def _assemble_generated_quantities(self) -> None: order='F', ) for chain in range(self.chains): - with open(self.runset.csv_files[chain], 'r') as fd: - lines = (line for line in fd if not line.startswith('#')) - gq_sample[:, chain, :] = np.loadtxt( - lines, dtype=np.ndarray, ndmin=2, skiprows=1, delimiter=',' + csv_file = self.runset.csv_files[chain] + try: + *_, draws = stancsv.parse_comments_header_and_draws( + self.runset.csv_files[chain] ) + gq_sample[:, chain, :] = stancsv.csv_bytes_list_to_numpy(draws) + except Exception as exc: + raise ValueError( + f"An error occurred when parsing Stan csv {csv_file}" + f" for chain {chain}" + ) from exc self._draws = gq_sample def _draws_start(self, inc_warmup: bool) -> Tuple[int, int]: diff --git a/cmdstanpy/stanfit/laplace.py b/cmdstanpy/stanfit/laplace.py index 00e5199a..bad0c9a6 100644 --- a/cmdstanpy/stanfit/laplace.py +++ b/cmdstanpy/stanfit/laplace.py @@ -24,8 +24,8 @@ XARRAY_INSTALLED = False from cmdstanpy.cmdstan_args import Method +from cmdstanpy.utils import stancsv from cmdstanpy.utils.data_munging import build_xarray_data -from cmdstanpy.utils.stancsv import scan_generic_csv from .metadata import InferenceMetadata from .mle import CmdStanMLE @@ -46,11 +46,8 @@ def __init__(self, runset: RunSet, mode: CmdStanMLE) -> None: ) self._runset = runset self._mode = mode - self._draws: np.ndarray = np.array(()) - - config = scan_generic_csv(runset.csv_files[0]) - self._metadata = InferenceMetadata(config) + self._metadata = InferenceMetadata.from_csv(self._runset.csv_files[0]) def create_inits( self, seed: Optional[int] = None, chains: int = 4 @@ -89,16 +86,16 @@ def _assemble_draws(self) -> None: if self._draws.shape != (0,): return - with open(self._runset.csv_files[0], 'r') as fd: - while (fd.readline()).startswith("#"): - pass - self._draws = np.loadtxt( - fd, - dtype=float, - ndmin=2, - delimiter=',', - comments="#", + csv_file = self._runset.csv_files[0] + try: + *_, draws = stancsv.parse_comments_header_and_draws( + self._runset.csv_files[0] ) + self._draws = stancsv.csv_bytes_list_to_numpy(draws) + except Exception as exc: + raise ValueError( + f"An error occurred when parsing Stan csv {csv_file}" + ) from exc def stan_variable(self, var: str) -> np.ndarray: """ @@ -318,7 +315,7 @@ def column_names(self) -> Tuple[str, ...]: and quantities of interest. Corresponds to Stan CSV file header row, with names munged to array notation, e.g. `beta[1]` not `beta.1`. """ - return self._metadata.cmdstan_config['column_names'] # type: ignore + return self._metadata.column_names def save_csvfiles(self, dir: Optional[str] = None) -> None: """ diff --git a/cmdstanpy/stanfit/mcmc.py b/cmdstanpy/stanfit/mcmc.py index 91f74d36..0d0f3e0b 100644 --- a/cmdstanpy/stanfit/mcmc.py +++ b/cmdstanpy/stanfit/mcmc.py @@ -218,7 +218,7 @@ def column_names(self) -> Tuple[str, ...]: and quantities of interest. Corresponds to Stan CSV file header row, with names munged to array notation, e.g. `beta[1]` not `beta.1`. """ - return self._metadata.cmdstan_config['column_names'] # type: ignore + return self._metadata.column_names @property def metric_type(self) -> Optional[str]: @@ -345,7 +345,6 @@ def _validate_csv_files(self) -> Dict[str, Any]: if i == 0: dzero = check_sampler_csv( path=self.runset.csv_files[i], - is_fixed_param=self._is_fixed_param, iter_sampling=self._iter_sampling, iter_warmup=self._iter_warmup, save_warmup=self._save_warmup, @@ -358,7 +357,6 @@ def _validate_csv_files(self) -> Dict[str, Any]: else: drest = check_sampler_csv( path=self.runset.csv_files[i], - is_fixed_param=self._is_fixed_param, iter_sampling=self._iter_sampling, iter_warmup=self._iter_warmup, save_warmup=self._save_warmup, @@ -444,14 +442,20 @@ def _assemble_draws(self) -> None: mass_matrix_per_chain = [] for chain in range(self.chains): try: - comments, draws = stancsv.parse_stan_csv_comments_and_draws( + ( + comments, + header, + draws, + ) = stancsv.parse_comments_header_and_draws( self.runset.csv_files[chain] ) - self._draws[:, chain, :] = stancsv.csv_bytes_list_to_numpy( - draws - ) + draws_np = stancsv.csv_bytes_list_to_numpy(draws) + if draws_np.shape[0] == 0: + n_cols = header.count(",") + 1 # type: ignore + draws_np = np.empty((0, n_cols)) + self._draws[:, chain, :] = draws_np if not self._is_fixed_param: ( self._step_size[chain], diff --git a/cmdstanpy/stanfit/metadata.py b/cmdstanpy/stanfit/metadata.py index 4869f2a0..61725116 100644 --- a/cmdstanpy/stanfit/metadata.py +++ b/cmdstanpy/stanfit/metadata.py @@ -1,10 +1,13 @@ """Container for metadata parsed from the output of a CmdStan run""" import copy -from typing import Any, Dict +import os +from typing import Any, Dict, Iterator, Tuple, Union import stanio +from cmdstanpy.utils import stancsv + class InferenceMetadata: """ @@ -13,10 +16,13 @@ class InferenceMetadata: Assumes valid CSV files. """ - def __init__(self, config: Dict[str, Any]) -> None: + def __init__( + self, config: Dict[str, Union[str, int, float, Tuple[str, ...]]] + ) -> None: """Initialize object from CSV headers""" self._cmdstan_config = config - vars = stanio.parse_header(config['raw_header']) + + vars = stanio.parse_header(config['raw_header']) # type: ignore self._method_vars = { k: v for (k, v) in vars.items() if k.endswith('__') @@ -25,9 +31,26 @@ def __init__(self, config: Dict[str, Any]) -> None: k: v for (k, v) in vars.items() if not k.endswith('__') } + @classmethod + def from_csv( + cls, stan_csv: Union[str, os.PathLike, Iterator[bytes]] + ) -> 'InferenceMetadata': + try: + comments, header, _ = stancsv.parse_comments_header_and_draws( + stan_csv + ) + return cls(stancsv.construct_config_header_dict(comments, header)) + except Exception as exc: + raise ValueError( + f"An error occurred when parsing Stan csv {stan_csv}" + ) from exc + def __repr__(self) -> str: return 'Metadata:\n{}\n'.format(self._cmdstan_config) + def __getitem__(self, key: str) -> Union[str, int, float, Tuple[str, ...]]: + return self._cmdstan_config[key] + @property def cmdstan_config(self) -> Dict[str, Any]: """ @@ -38,6 +61,11 @@ def cmdstan_config(self) -> Dict[str, Any]: """ return copy.deepcopy(self._cmdstan_config) + @property + def column_names(self) -> Tuple[str, ...]: + col_names = self['column_names'] + return col_names # type: ignore + @property def method_vars(self) -> Dict[str, stanio.Variable]: """ diff --git a/cmdstanpy/stanfit/mle.py b/cmdstanpy/stanfit/mle.py index fd599dbf..8f28fc3d 100644 --- a/cmdstanpy/stanfit/mle.py +++ b/cmdstanpy/stanfit/mle.py @@ -7,7 +7,7 @@ import pandas as pd from cmdstanpy.cmdstan_args import Method, OptimizeArgs -from cmdstanpy.utils import get_logger, scan_optimize_csv +from cmdstanpy.utils import get_logger, stancsv from .metadata import InferenceMetadata from .runset import RunSet @@ -34,7 +34,28 @@ def __init__(self, runset: RunSet) -> None: optimize_args, OptimizeArgs ) # make the typechecker happy self._save_iterations: bool = optimize_args.save_iterations - self._set_mle_attrs(runset.csv_files[0]) + + csv_file = self.runset.csv_files[0] + try: + ( + comment_lines, + header, + draws_lines, + ) = stancsv.parse_comments_header_and_draws( + self.runset.csv_files[0] + ) + self._metadata = InferenceMetadata( + stancsv.construct_config_header_dict(comment_lines, header) + ) + all_draws = stancsv.csv_bytes_list_to_numpy(draws_lines) + + except Exception as exc: + raise ValueError( + f"An error occurred when parsing Stan csv {csv_file}" + ) from exc + self._mle: np.ndarray = all_draws[-1] + if self._save_iterations: + self._all_iters: np.ndarray = all_draws def create_inits( self, seed: Optional[int] = None, chains: int = 4 @@ -84,21 +105,13 @@ def __getattr__(self, attr: str) -> Union[np.ndarray, float]: # pylint: disable=raise-missing-from raise AttributeError(*e.args) - def _set_mle_attrs(self, sample_csv_0: str) -> None: - meta = scan_optimize_csv(sample_csv_0, self._save_iterations) - self._metadata = InferenceMetadata(meta) - self._column_names: Tuple[str, ...] = meta['column_names'] - self._mle: np.ndarray = meta['mle'] - if self._save_iterations: - self._all_iters: np.ndarray = meta['all_iters'] - @property def column_names(self) -> Tuple[str, ...]: """ Names of estimated quantities, includes joint log probability, and all parameters, transformed parameters, and generated quantities. """ - return self._column_names + return self.metadata.column_names @property def metadata(self) -> InferenceMetadata: diff --git a/cmdstanpy/stanfit/pathfinder.py b/cmdstanpy/stanfit/pathfinder.py index 5ac4d213..bbedc146 100644 --- a/cmdstanpy/stanfit/pathfinder.py +++ b/cmdstanpy/stanfit/pathfinder.py @@ -9,7 +9,7 @@ from cmdstanpy.cmdstan_args import Method from cmdstanpy.stanfit.metadata import InferenceMetadata from cmdstanpy.stanfit.runset import RunSet -from cmdstanpy.utils.stancsv import scan_generic_csv +from cmdstanpy.utils import stancsv class CmdStanPathfinder: @@ -26,11 +26,8 @@ def __init__(self, runset: RunSet): 'found method {}'.format(runset.method) ) self._runset = runset - self._draws: np.ndarray = np.array(()) - - config = scan_generic_csv(runset.csv_files[0]) - self._metadata = InferenceMetadata(config) + self._metadata = InferenceMetadata.from_csv(self._runset.csv_files[0]) def create_inits( self, seed: Optional[int] = None, chains: int = 4 @@ -77,21 +74,20 @@ def __repr__(self) -> str: ) return rep - # below this is identical to same functions in Laplace def _assemble_draws(self) -> None: if self._draws.shape != (0,): return - with open(self._runset.csv_files[0], 'r') as fd: - while (fd.readline()).startswith("#"): - pass - self._draws = np.loadtxt( - fd, - dtype=float, - ndmin=2, - delimiter=',', - comments="#", + csv_file = self._runset.csv_files[0] + try: + *_, draws = stancsv.parse_comments_header_and_draws( + self._runset.csv_files[0] ) + self._draws = stancsv.csv_bytes_list_to_numpy(draws) + except Exception as exc: + raise ValueError( + f"An error occurred when parsing Stan csv {csv_file}" + ) from exc def stan_variable(self, var: str) -> np.ndarray: """ @@ -204,7 +200,7 @@ def column_names(self) -> Tuple[str, ...]: and quantities of interest. Corresponds to Stan CSV file header row, with names munged to array notation, e.g. `beta[1]` not `beta.1`. """ - return self._metadata.cmdstan_config['column_names'] # type: ignore + return self._metadata.column_names @property def is_resampled(self) -> bool: diff --git a/cmdstanpy/stanfit/vb.py b/cmdstanpy/stanfit/vb.py index 8d7ac552..2c8f3e20 100644 --- a/cmdstanpy/stanfit/vb.py +++ b/cmdstanpy/stanfit/vb.py @@ -7,7 +7,7 @@ import pandas as pd from cmdstanpy.cmdstan_args import Method -from cmdstanpy.utils import scan_variational_csv +from cmdstanpy.utils import stancsv from cmdstanpy.utils.logging import get_logger from .metadata import InferenceMetadata @@ -28,7 +28,30 @@ def __init__(self, runset: RunSet) -> None: 'found method {}'.format(runset.method) ) self.runset = runset - self._set_variational_attrs(runset.csv_files[0]) + + csv_file = self.runset.csv_files[0] + try: + ( + comment_lines, + header, + draw_lines, + ) = stancsv.parse_comments_header_and_draws( + self.runset.csv_files[0] + ) + + self._metadata = InferenceMetadata( + stancsv.construct_config_header_dict(comment_lines, header) + ) + self._eta = stancsv.parse_variational_eta(comment_lines) + + draws_np = stancsv.csv_bytes_list_to_numpy(draw_lines) + + except Exception as exc: + raise ValueError( + f"An error occurred when parsing Stan csv {csv_file}" + ) from exc + self._variational_mean: np.ndarray = draws_np[0] + self._variational_sample: np.ndarray = draws_np[1:] def create_inits( self, seed: Optional[int] = None, chains: int = 4 @@ -87,15 +110,6 @@ def __getattr__(self, attr: str) -> Union[np.ndarray, float]: # pylint: disable=raise-missing-from raise AttributeError(*e.args) - def _set_variational_attrs(self, sample_csv_0: str) -> None: - meta = scan_variational_csv(sample_csv_0) - self._metadata = InferenceMetadata(meta) - # these three assignments don't grant type information - self._column_names: Tuple[str, ...] = meta['column_names'] - self._eta: float = meta['eta'] - self._variational_mean: np.ndarray = meta['variational_mean'] - self._variational_sample: np.ndarray = meta['variational_sample'] - @property def columns(self) -> int: """ @@ -103,7 +117,7 @@ def columns(self) -> int: Includes approximation information and names of model parameters and computed quantities. """ - return len(self._column_names) + return len(self.column_names) @property def column_names(self) -> Tuple[str, ...]: @@ -112,7 +126,7 @@ def column_names(self) -> Tuple[str, ...]: Includes approximation information and names of model parameters and computed quantities. """ - return self._column_names + return self.metadata.column_names @property def eta(self) -> float: diff --git a/cmdstanpy/utils/__init__.py b/cmdstanpy/utils/__init__.py index 8d61a289..570c4f9a 100644 --- a/cmdstanpy/utils/__init__.py +++ b/cmdstanpy/utils/__init__.py @@ -1,6 +1,7 @@ """ Utility functions """ + import os import platform import sys @@ -28,21 +29,8 @@ windows_short_path, ) from .json import write_stan_json -from .logging import get_logger, enable_logging, disable_logging -from .stancsv import ( - check_sampler_csv, - parse_rdump_value, - read_metric, - rload, - scan_column_names, - scan_config, - scan_hmc_params, - scan_optimize_csv, - scan_sampler_csv, - scan_sampling_iters, - scan_variational_csv, - scan_warmup_iters, -) +from .logging import disable_logging, enable_logging, get_logger +from .stancsv import check_sampler_csv, parse_rdump_value, read_metric, rload def show_versions(output: bool = True) -> str: @@ -128,14 +116,6 @@ def show_versions(output: bool = True) -> str: 'read_metric', 'returncode_msg', 'rload', - 'scan_column_names', - 'scan_config', - 'scan_hmc_params', - 'scan_optimize_csv', - 'scan_sampler_csv', - 'scan_sampling_iters', - 'scan_variational_csv', - 'scan_warmup_iters', 'set_cmdstan_path', 'set_make_env', 'show_versions', diff --git a/cmdstanpy/utils/stancsv.py b/cmdstanpy/utils/stancsv.py index 74830994..f68260fe 100644 --- a/cmdstanpy/utils/stancsv.py +++ b/cmdstanpy/utils/stancsv.py @@ -8,60 +8,69 @@ import os import re import warnings -from typing import ( - Any, - Dict, - Iterator, - List, - MutableMapping, - Optional, - TextIO, - Tuple, - Union, -) +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import numpy as np import numpy.typing as npt -import pandas as pd from cmdstanpy import _CMDSTAN_SAMPLING, _CMDSTAN_THIN, _CMDSTAN_WARMUP -def parse_stan_csv_comments_and_draws( +def parse_comments_header_and_draws( stan_csv: Union[str, os.PathLike, Iterator[bytes]], -) -> Tuple[List[bytes], List[bytes]]: - """Parses lines of a Stan CSV file into comment lines and draws lines, where - a draws line is just a non-commented line. +) -> Tuple[List[bytes], Optional[str], List[bytes]]: + """Parses lines of a Stan CSV file into comment lines, the header line, + and draws lines. - Returns a (comment_lines, draws_lines) tuple. + Returns a (comment_lines, header, draws_lines) tuple. """ - def split_comments_and_draws( + def partition_csv( lines: Iterator[bytes], - ) -> Tuple[List[bytes], List[bytes]]: - comment_lines, draws_lines = [], [] + ) -> Tuple[List[bytes], Optional[str], List[bytes]]: + comment_lines: List[bytes] = [] + draws_lines: List[bytes] = [] + header = None for line in lines: if line.startswith(b"#"): # is comment line comment_lines.append(line) + elif header is None: # Assumes the header is the first non-comment + header = line.strip().decode() else: draws_lines.append(line) - return comment_lines, draws_lines + return comment_lines, header, draws_lines if isinstance(stan_csv, (str, os.PathLike)): with open(stan_csv, "rb") as f: - return split_comments_and_draws(f) + return partition_csv(f) else: - return split_comments_and_draws(stan_csv) + return partition_csv(stan_csv) + + +def filter_csv_bytes_by_columns( + csv_bytes_list: List[bytes], indexes_to_keep: List[int] +) -> List[bytes]: + """Given the list of bytes representing the lines of a CSV file + and the indexes of columns to keep, will return a new list of bytes + containing only those columns in the index order provided. Assumes + column-delimited columns.""" + out = [] + for dl in csv_bytes_list: + split = dl.strip().split(b",") + out.append(b",".join(split[i] for i in indexes_to_keep) + b"\n") + return out def csv_bytes_list_to_numpy( - csv_bytes_list: List[bytes], includes_header: bool = True + csv_bytes_list: List[bytes], ) -> npt.NDArray[np.float64]: """Efficiently converts a list of bytes representing whose concatenation - represents a CSV file into a numpy array. Includes header specifies - whether the bytes contains an initial header line.""" + represents a CSV file into a numpy array. + + Returns a 2D numpy array with shape (n_rows, n_cols). If no data is found, + returns an empty array with shape (0, 0).""" if not csv_bytes_list: - return np.empty((0,)) + return np.empty((0, 0)) num_cols = csv_bytes_list[0].count(b",") + 1 try: import polars as pl @@ -70,7 +79,7 @@ def csv_bytes_list_to_numpy( out: npt.NDArray[np.float64] = ( pl.read_csv( io.BytesIO(b"".join(csv_bytes_list)), - has_header=includes_header, + has_header=False, schema_overrides=[pl.Float64] * num_cols, infer_schema=False, ) @@ -78,29 +87,25 @@ def csv_bytes_list_to_numpy( .astype(np.float64) ) except pl.exceptions.NoDataError: - return np.empty((0, num_cols)) + return np.empty((0, 0)) except ImportError: with warnings.catch_warnings(): warnings.filterwarnings("ignore") out = np.loadtxt( csv_bytes_list, - skiprows=int(includes_header), delimiter=",", dtype=np.float64, - ndmin=1, + ndmin=2, ) - if len(out.shape) == 1: - if out.shape[0] == 0: # No data read - out = np.empty((0, num_cols)) - else: - out = out.reshape(1, -1) + if out.shape[0] == 0: # No data read + out = np.empty((0, 0)) return out def parse_hmc_adaptation_lines( comment_lines: List[bytes], -) -> Tuple[float, Optional[npt.NDArray[np.float64]]]: +) -> Tuple[Optional[float], Optional[npt.NDArray[np.float64]]]: """Extracts step size/mass matrix information from the Stan CSV comment lines by parsing the adaptation section. If the diag_e metric is used, the returned mass matrix will be a 1D array of the diagnoal elements, @@ -129,245 +134,385 @@ def parse_hmc_adaptation_lines( break elif b"diag_e" in line: diag_e_metric = True - if step_size is None: - raise ValueError("Unable to parse adapated step size") if matrix_lines: - mass_matrix = csv_bytes_list_to_numpy( - matrix_lines, includes_header=False - ) + mass_matrix = csv_bytes_list_to_numpy(matrix_lines) if diag_e_metric and mass_matrix.shape[0] == 1: mass_matrix = mass_matrix[0] return step_size, mass_matrix +def extract_key_val_pairs( + comment_lines: List[bytes], remove_default_text: bool = True +) -> Iterator[Tuple[str, str]]: + """Yields cleaned key = val pairs from stan csv comments. + Removes '(Default)' text from values if remove_default_text is True.""" + cleaned_lines = ( + line.decode().lstrip("# ").strip() for line in comment_lines + ) + for line in cleaned_lines: + split_on_eq = line.split(" = ") + # Only want lines with key = value + if len(split_on_eq) != 2: + continue + + key, val = split_on_eq + if remove_default_text: + val = val.replace("(Default)", "").strip() + yield key, val + + +def parse_config( + comment_lines: List[bytes], +) -> Dict[str, Union[str, int, float]]: + """Extracts the key=value config settings from Stan CSV comment + lines and returns a dictionary.""" + out: Dict[str, Union[str, int, float]] = {} + for key, val in extract_key_val_pairs(comment_lines): + if key == 'file': + if not val.endswith('csv'): + out['data_file'] = val + else: + if val == 'true': + out[key] = 1 + elif val == 'false': + out[key] = 0 + else: + for cast in (int, float): + try: + out[key] = cast(val) + break + except ValueError: + pass + else: + out[key] = val + return out + + +def parse_header(header: str) -> Tuple[str, ...]: + """Returns munged variable names from a Stan csv header line""" + return tuple(munge_varname(name) for name in header.split(",")) + + +def construct_config_header_dict( + comment_lines: List[bytes], header: Optional[str] +) -> Dict[str, Union[str, int, float, Tuple[str, ...]]]: + """Extracts config and header info from comment/draws lines parsed + from a Stan CSV file.""" + config = parse_config(comment_lines) + out: Dict[str, Union[str, int, float, Tuple[str, ...]]] = {**config} + if header: + out["raw_header"] = header + out["column_names"] = parse_header(header) + return out + + +def parse_variational_eta(comment_lines: List[bytes]) -> float: + """Extracts the variational eta parameter from stancsv comment lines""" + for i, line in enumerate(comment_lines): + if line.startswith(b"# Stepsize adaptation") and ( + i + 1 < len(comment_lines) # Ensure i + 1 is in bounds + ): + eta_line = comment_lines[i + 1] + break + else: + raise ValueError( + "Unable to parse eta from Stan CSV, adaptation block not found" + ) + + _, val = eta_line.split(b" = ") + return float(val) + + +def extract_max_treedepth_and_divergence_counts( + header: str, draws_lines: List[bytes], max_treedepth: int, warmup_draws: int +) -> Tuple[int, int]: + """Extracts the max treedepth and divergence counts from the header + and draw lines of the MCMC stan csv output.""" + if len(draws_lines) <= 1: # Empty draws + return 0, 0 + column_names = header.split(",") + + try: + indexes_to_keep = [ + column_names.index("treedepth__"), + column_names.index("divergent__"), + ] + except ValueError: + # Throws if treedepth/divergent columns not recorded + return 0, 0 + + sampling_draws = draws_lines[1 + warmup_draws :] + + filtered = filter_csv_bytes_by_columns(sampling_draws, indexes_to_keep) + arr = csv_bytes_list_to_numpy(filtered).astype(int) + + num_max_treedepth = np.sum(arr[:, 0] == max_treedepth) + num_divergences = np.sum(arr[:, 1]) + return num_max_treedepth, num_divergences + + +# TODO: Remove after CmdStan 2.37 is the minimum version +def is_sneaky_fixed_param(header: str) -> bool: + """Returns True if the header line indicates that the sampler + ran with the fixed_param sampler automatically, despite the + algorithm listed as 'hmc'. + + See issue #805""" + num_dunder_cols = sum(col.endswith("__") for col in header.split(",")) + + return (num_dunder_cols < 7) and "lp__" in header + + +def count_warmup_and_sampling_draws( + stan_csv: Union[str, os.PathLike, Iterator[bytes]], +) -> Tuple[int, int]: + """Scans through a Stan CSV file to count the number of lines in the + warmup/sampling blocks to determine counts for warmup and sampling draws. + """ + + def determine_draw_counts(lines: Iterator[bytes]) -> Tuple[int, int]: + is_fixed_param = False + header_line_idx = None + adaptation_block_idx = None + sampling_block_idx = None + timing_block_idx = None + for i, line in enumerate(lines): + if header_line_idx is None: + if b"fixed_param" in line: + is_fixed_param = True + if line.startswith(b"lp__"): + header_line_idx = i + if not is_fixed_param: + is_fixed_param = is_sneaky_fixed_param( + line.strip().decode() + ) + continue + + if not is_fixed_param and adaptation_block_idx is None: + if line.startswith(b"#"): + adaptation_block_idx = i + elif sampling_block_idx is None: + if not line.startswith(b"#"): + sampling_block_idx = i + elif line.startswith(b"# Elapsed"): + sampling_block_idx = i + timing_block_idx = i + elif timing_block_idx is None: + if line.startswith(b"#"): + timing_block_idx = i + else: + break + else: + # Will raise if lines exhausts without all blocks being identified + raise ValueError( + "Unable to count warmup and sampling draws from Stan csv" + ) + + if is_fixed_param: + num_warmup = 0 + else: + num_warmup = ( + adaptation_block_idx - header_line_idx - 1 # type: ignore + ) + num_sampling = timing_block_idx - sampling_block_idx + return num_warmup, num_sampling + + if isinstance(stan_csv, (str, os.PathLike)): + with open(stan_csv, "rb") as f: + return determine_draw_counts(f) + else: + return determine_draw_counts(stan_csv) + + +def raise_on_inconsistent_draws_shape( + header: str, draw_lines: List[bytes] +) -> None: + """Throws a ValueError if any draws are found to have an inconsistent + shape, i.e. too many/few columns compared to the header""" + + def column_count(ln: bytes) -> int: + return ln.count(b",") + 1 + + # Consider empty draws to be consistent + if not draw_lines: + return + + num_cols = column_count(header.encode()) + for i, draw in enumerate(draw_lines, start=1): + if (draw_size := column_count(draw)) != num_cols: + raise ValueError( + f"line {i}: bad draw, expecting {num_cols} items," + f" found {draw_size}" + ) + + +def raise_on_invalid_adaptation_block(comment_lines: List[bytes]) -> None: + """Throws ValueErrors if the parsed adaptation block is invalid, e.g. + the metric information is not present, consistent with the rest of + the file, or the step size info cannot be processed.""" + + def column_count(ln: bytes) -> int: + return ln.count(b",") + 1 + + ln_iter = enumerate(comment_lines, start=2) + metric = None + for _, line in ln_iter: + if b"metric =" in line: + _, val = line.split(b" = ") + metric = val.replace(b"(Default)", b"").strip().decode() + if b"Adaptation terminated" in line: + break + else: # No adaptation block found + raise ValueError("No adaptation block found, expecting metric") + + if metric is None: + raise ValueError("No reported metric found") + # At this point iterator should be in the adaptation block + + # Ensure step size exists and is valid float + num, line = next(ln_iter) + if not line.startswith(b"# Step size"): + raise ValueError( + f"line {num}: expecting step size, " + f"found:\n\t \"{line.decode()}\"" + ) + _, step_size = line.split(b" = ") + try: + float(step_size.strip()) + except ValueError as exc: + raise ValueError( + f"line {num}: invalid step size: {step_size.decode()}" + ) from exc + + # Ensure mass matrix valid + num, line = next(ln_iter) + if metric == "unit_e": + return + if not ( + (metric == "diag_e" and line.startswith(b"# Diagonal elements of ")) + or (metric == "dense_e" and line.startswith(b"# Elements of inverse")) + ): + raise ValueError( + f"line {num}: invalid or missing mass matrix specification" + ) + + # Validating mass matrix shape + _, line = next(ln_iter) + num_unconstrained_params = column_count(line) + if metric == "diag_e": + return + for (num, line), _ in zip(ln_iter, range(1, num_unconstrained_params)): + if column_count(line) != num_unconstrained_params: + raise ValueError( + f"line {num}: invalid or missing mass matrix specification" + ) + + +def parse_timing_lines( + comment_lines: List[bytes], +) -> Dict[str, float]: + """Parse the timing lines into a dictionary with key corresponding + to the phase, e.g. Warm-up, Sampling, Total, and value the elapsed seconds + """ + out: Dict[str, float] = {} + + cleaned_lines = (ln.lstrip(b"# ") for ln in comment_lines) + in_timing_block = False + for line in cleaned_lines: + if line.startswith(b"Elapsed Time") and not in_timing_block: + in_timing_block = True + + if not in_timing_block: + continue + match = re.findall(r"([\d\.]+) seconds \((.+)\)", str(line)) + if match: + seconds = float(match[0][0]) + phase = match[0][1] + out[phase] = seconds + return out + + def check_sampler_csv( - path: str, - is_fixed_param: bool = False, - iter_sampling: Optional[int] = None, - iter_warmup: Optional[int] = None, + path: Union[str, os.PathLike], + iter_sampling: int = _CMDSTAN_SAMPLING, + iter_warmup: int = _CMDSTAN_WARMUP, save_warmup: bool = False, - thin: Optional[int] = None, + thin: int = _CMDSTAN_THIN, ) -> Dict[str, Any]: """Capture essential config, shape from stan_csv file.""" - meta = scan_sampler_csv(path, is_fixed_param) - if thin is None: - thin = _CMDSTAN_THIN - elif thin > _CMDSTAN_THIN: + meta = parse_sampler_metadata_from_csv(path) + if thin > _CMDSTAN_THIN: if 'thin' not in meta: raise ValueError( - 'bad Stan CSV file {}, ' - 'config error, expected thin = {}'.format(path, thin) + f'Bad Stan CSV file {path}, config error, ' + f'expected thin = {thin}' ) if meta['thin'] != thin: raise ValueError( - 'bad Stan CSV file {}, ' - 'config error, expected thin = {}, found {}'.format( - path, thin, meta['thin'] - ) + f'Bad Stan CSV file {path}, ' + f'config error, expected thin = {thin}, found {meta["thin"]}' ) - draws_sampling = iter_sampling - if draws_sampling is None: - draws_sampling = _CMDSTAN_SAMPLING - draws_warmup = iter_warmup - if draws_warmup is None: - draws_warmup = _CMDSTAN_WARMUP - draws_warmup = int(math.ceil(draws_warmup / thin)) - draws_sampling = int(math.ceil(draws_sampling / thin)) + draws_warmup = int(math.ceil(iter_warmup / thin)) + draws_sampling = int(math.ceil(iter_sampling / thin)) if meta['draws_sampling'] != draws_sampling: raise ValueError( - 'bad Stan CSV file {}, expected {} draws, found {}'.format( - path, draws_sampling, meta['draws_sampling'] - ) + f'Bad Stan CSV file {path}, expected {draws_sampling} draws, ' + f'found {meta["draws_sampling"]}' ) if save_warmup: - if not ('save_warmup' in meta and meta['save_warmup'] in (1, 'true')): + if not ('save_warmup' in meta and meta['save_warmup'] == 1): raise ValueError( - 'bad Stan CSV file {}, ' - 'config error, expected save_warmup = 1'.format(path) + f'Bad Stan CSV file {path}, ' + 'config error, expected save_warmup = 1' ) if meta['draws_warmup'] != draws_warmup: raise ValueError( - 'bad Stan CSV file {}, ' - 'expected {} warmup draws, found {}'.format( - path, draws_warmup, meta['draws_warmup'] - ) + f'Bad Stan CSV file {path}, expected {draws_warmup} ' + f'warmup draws, found {meta["draws_warmup"]}' ) return meta -def scan_sampler_csv(path: str, is_fixed_param: bool = False) -> Dict[str, Any]: - """Process sampler stan_csv output file line by line.""" - dict: Dict[str, Any] = {} - lineno = 0 - with open(path, 'r') as fd: - try: - lineno = scan_config(fd, dict, lineno) - lineno = scan_column_names(fd, dict, lineno) - if not is_fixed_param: - lineno = scan_warmup_iters(fd, dict, lineno) - lineno = scan_hmc_params(fd, dict, lineno) - lineno = scan_sampling_iters(fd, dict, lineno, is_fixed_param) - lineno = scan_time(fd, dict, lineno) - except ValueError as e: - raise ValueError("Error in reading csv file: " + path) from e - return dict - - -def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]: - """Process optimizer stan_csv output file line by line.""" - dict: Dict[str, Any] = {} - lineno = 0 - # scan to find config, header, num saved iters - with open(path, 'r') as fd: - lineno = scan_config(fd, dict, lineno) - lineno = scan_column_names(fd, dict, lineno) - iters = 0 - for line in fd: - iters += 1 - if save_iters: - all_iters: np.ndarray = np.empty( - (iters, len(dict['column_names'])), dtype=float, order='F' - ) - # rescan to capture estimates - with open(path, 'r') as fd: - for i in range(lineno): - fd.readline() - for i in range(iters): - line = fd.readline().strip() - if len(line) < 1: - raise ValueError( - 'cannot parse CSV file {}, error at line {}'.format( - path, lineno + i - ) - ) - xs = line.split(',') - if save_iters: - all_iters[i, :] = [float(x) for x in xs] - if i == iters - 1: - mle: np.ndarray = np.array(xs, dtype=float) - # pylint: disable=possibly-used-before-assignment - dict['mle'] = mle - if save_iters: - dict['all_iters'] = all_iters - return dict - - -def scan_generic_csv(path: str) -> Dict[str, Any]: - """Process laplace stan_csv output file line by line.""" - dict: Dict[str, Any] = {} - lineno = 0 - with open(path, 'r') as fd: - lineno = scan_config(fd, dict, lineno) - lineno = scan_column_names(fd, dict, lineno) - return dict - - -def scan_variational_csv(path: str) -> Dict[str, Any]: - """Process advi stan_csv output file line by line.""" - dict: Dict[str, Any] = {} - lineno = 0 - with open(path, 'r') as fd: - lineno = scan_config(fd, dict, lineno) - lineno = scan_column_names(fd, dict, lineno) - line = fd.readline().lstrip(' #\t').rstrip() - lineno += 1 - if line.startswith('Stepsize adaptation complete.'): - line = fd.readline().lstrip(' #\t\n') - lineno += 1 - if not line.startswith('eta'): - raise ValueError( - 'line {}: expecting eta, found:\n\t "{}"'.format( - lineno, line - ) - ) - _, eta = line.split('=') - dict['eta'] = float(eta) - line = fd.readline().lstrip(' #\t\n') - lineno += 1 - xs = line.split(',') - variational_mean = [float(x) for x in xs] - dict['variational_mean'] = np.array(variational_mean) - dict['variational_sample'] = pd.read_csv( - path, - comment='#', - skiprows=lineno, - header=None, - float_precision='high', - ).to_numpy() - return dict - - -def scan_config(fd: TextIO, config_dict: Dict[str, Any], lineno: int) -> int: - """ - Scan initial stan_csv file comments lines and - save non-default configuration information to config_dict. - """ - cur_pos = fd.tell() - line = fd.readline().strip() - while len(line) > 0 and line.startswith('#'): - lineno += 1 - if line.endswith('(Default)'): - line = line.replace('(Default)', '') - line = line.lstrip(' #\t') - key_val = line.split('=') - if len(key_val) == 2: - if key_val[0].strip() == 'file' and not key_val[1].endswith('csv'): - config_dict['data_file'] = key_val[1].strip() - elif key_val[0].strip() != 'file': - raw_val = key_val[1].strip() - val: Union[int, float, str] - try: - val = int(raw_val) - except ValueError: - try: - val = float(raw_val) - except ValueError: - if raw_val == "true": - val = 1 - elif raw_val == "false": - val = 0 - else: - val = raw_val - config_dict[key_val[0].strip()] = val - cur_pos = fd.tell() - line = fd.readline().strip() - fd.seek(cur_pos) - return lineno - - -def scan_warmup_iters( - fd: TextIO, config_dict: Dict[str, Any], lineno: int -) -> int: - """ - Check warmup iterations, if any. - """ - if 'save_warmup' not in config_dict: - return lineno - cur_pos = fd.tell() - line = fd.readline().strip() - draws_found = 0 - while len(line) > 0 and not line.startswith('#'): - lineno += 1 - draws_found += 1 - cur_pos = fd.tell() - line = fd.readline().strip() - fd.seek(cur_pos) - config_dict['draws_warmup'] = draws_found - return lineno - - -def scan_column_names( - fd: TextIO, config_dict: MutableMapping[str, Any], lineno: int -) -> int: - """ - Process columns header, add to config_dict as 'column_names' - """ - line = fd.readline().strip() - lineno += 1 - config_dict['raw_header'] = line.strip() - names = line.split(',') - config_dict['column_names'] = tuple(munge_varnames(names)) - return lineno +def parse_sampler_metadata_from_csv( + path: Union[str, os.PathLike], +) -> Dict[str, Union[int, float, str, Tuple[str, ...], Dict[str, float]]]: + """Parses sampling metadata from a given Stan CSV path for a sample run""" + try: + comments, header, draws = parse_comments_header_and_draws(path) + if header is None: + raise ValueError("No header line found in stan csv") + raise_on_inconsistent_draws_shape(header, draws) + config = construct_config_header_dict(comments, header) + num_warmup, num_sampling = count_warmup_and_sampling_draws(path) + timings = parse_timing_lines(comments) + if ( + (config['algorithm'] != 'fixed_param') + and header + and not is_sneaky_fixed_param(header) + ): + raise_on_invalid_adaptation_block(comments) + max_depth: int = config["max_depth"] # type: ignore + max_tree_hits, divs = extract_max_treedepth_and_divergence_counts( + header, draws, max_depth, num_warmup + ) + else: + max_tree_hits, divs = 0, 0 + except (KeyError, ValueError) as exc: + raise ValueError(f"Error in reading csv file: {path}") from exc + + key_renames = { + "Warm-up": "warmup", + "Sampling": "sampling", + "Total": "total", + } + addtl: Dict[str, Union[int, Dict[str, float]]] = { + "draws_warmup": num_warmup, + "draws_sampling": num_sampling, + "ct_divergences": divs, + "ct_max_treedepth": max_tree_hits, + "time": {key_renames[k]: v for k, v in timings.items()}, + } + return {**config, **addtl} def munge_varname(name: str) -> str: @@ -386,190 +531,6 @@ def munge_varname(name: str) -> str: return '.'.join(tuple_parts) -def munge_varnames(names: List[str]) -> List[str]: - """ - Change formatting for indices of container var elements - from use of dot separator to array-like notation, e.g., - rewrite label ``y_forecast.2.4`` to ``y_forecast[2,4]``. - """ - if names is None: - raise ValueError('missing argument "names"') - return [munge_varname(name) for name in names] - - -def scan_hmc_params( - fd: TextIO, config_dict: Dict[str, Any], lineno: int -) -> int: - """ - Scan step size, metric from stan_csv file comment lines. - """ - metric = config_dict['metric'] - line = fd.readline().strip() - lineno += 1 - if not line == '# Adaptation terminated': - raise ValueError( - 'line {}: expecting metric, found:\n\t "{}"'.format(lineno, line) - ) - line = fd.readline().strip() - lineno += 1 - label, step_size = line.split('=') - if not label.startswith('# Step size'): - raise ValueError( - 'line {}: expecting step size, ' - 'found:\n\t "{}"'.format(lineno, line) - ) - try: - float(step_size.strip()) - except ValueError as e: - raise ValueError( - 'line {}: invalid step size: {}'.format(lineno, step_size) - ) from e - before_metric = fd.tell() - line = fd.readline().strip() - lineno += 1 - if metric == 'unit_e': - if line.startswith("# No free parameters"): - return lineno - else: - fd.seek(before_metric) - return lineno - 1 - - if not ( - ( - metric == 'diag_e' - and line == '# Diagonal elements of inverse mass matrix:' - ) - or ( - metric == 'dense_e' and line == '# Elements of inverse mass matrix:' - ) - ): - raise ValueError( - 'line {}: invalid or missing mass matrix ' - 'specification'.format(lineno) - ) - line = fd.readline().lstrip(' #\t') - lineno += 1 - num_unconstrained_params = len(line.split(',')) - if metric == 'diag_e': - return lineno - else: - for _ in range(1, num_unconstrained_params): - line = fd.readline().lstrip(' #\t') - lineno += 1 - if len(line.split(',')) != num_unconstrained_params: - raise ValueError( - 'line {}: invalid or missing mass matrix ' - 'specification'.format(lineno) - ) - return lineno - - -def scan_sampling_iters( - fd: TextIO, config_dict: Dict[str, Any], lineno: int, is_fixed_param: bool -) -> int: - """ - Parse sampling iteration, save number of iterations to config_dict. - Also save number of divergences, max_treedepth hits - """ - draws_found = 0 - num_cols = len(config_dict['column_names']) - if not is_fixed_param: - idx_divergent = config_dict['column_names'].index('divergent__') - idx_treedepth = config_dict['column_names'].index('treedepth__') - max_treedepth = config_dict['max_depth'] - ct_divergences = 0 - ct_max_treedepth = 0 - - cur_pos = fd.tell() - line = fd.readline().strip() - while len(line) > 0 and not line.startswith('#'): - lineno += 1 - draws_found += 1 - data = line.split(',') - if len(data) != num_cols: - raise ValueError( - 'line {}: bad draw, expecting {} items, found {}\n'.format( - lineno, num_cols, len(line.split(',')) - ) - + 'This error could be caused by running out of disk space.\n' - 'Try clearing up TEMP or setting output_dir to a path' - ' on another drive.', - ) - cur_pos = fd.tell() - line = fd.readline().strip() - if not is_fixed_param: - ct_divergences += int(data[idx_divergent]) # type: ignore - if int(data[idx_treedepth]) == max_treedepth: # type: ignore - ct_max_treedepth += 1 - - fd.seek(cur_pos) - config_dict['draws_sampling'] = draws_found - if not is_fixed_param: - config_dict['ct_divergences'] = ct_divergences - config_dict['ct_max_treedepth'] = ct_max_treedepth - return lineno - - -def scan_time(fd: TextIO, config_dict: Dict[str, Any], lineno: int) -> int: - """ - Scan time information from the trailing comment lines in a Stan CSV file. - - # Elapsed Time: 0.001332 seconds (Warm-up) - # 0.000249 seconds (Sampling) - # 0.001581 seconds (Total) - - - It extracts the time values and saves them in the config_dict: key 'time', - value a dictionary with keys 'warmup', 'sampling', and 'total'. - Returns the updated line number after reading the time info. - - :param fd: Open file descriptor at comment row following all sample data. - :param config_dict: Dictionary to which the time info is added. - :param lineno: Current line number - """ - time = {} - keys = ['warmup', 'sampling', 'total'] - while True: - pos = fd.tell() - line = fd.readline() - if not line: - break - lineno += 1 - stripped = line.strip() - if not stripped.startswith('#'): - fd.seek(pos) - lineno -= 1 - break - content = stripped.lstrip('#').strip() - if not content: - continue - tokens = content.split() - if len(tokens) < 3: - raise ValueError(f"Invalid time at line {lineno}: {content}") - if 'Warm-up' in content: - key = 'warmup' - time_str = tokens[2] - elif 'Sampling' in content: - key = 'sampling' - time_str = tokens[0] - elif 'Total' in content: - key = 'total' - time_str = tokens[0] - else: - raise ValueError(f"Invalid time at line {lineno}: {content}") - try: - t = float(time_str) - except ValueError as e: - raise ValueError(f"Invalid time at line {lineno}: {content}") from e - time[key] = t - - if not all(key in time for key in keys): - raise ValueError(f"Invalid time, stopped at {lineno}") - - config_dict['time'] = time - return lineno - - def read_metric(path: str) -> List[int]: """ Read metric file in JSON or Rdump format. diff --git a/test/test_metadata.py b/test/test_metadata.py index bfaaf6bc..9e875a15 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -42,7 +42,6 @@ def test_good() -> None: runset._set_retcode(i, 0) config = check_sampler_csv( path=runset.csv_files[i], - is_fixed_param=False, iter_sampling=100, iter_warmup=1000, save_warmup=False, diff --git a/test/test_stancsv.py b/test/test_stancsv.py index dc6dd0ee..0fcad8fe 100644 --- a/test/test_stancsv.py +++ b/test/test_stancsv.py @@ -1,5 +1,6 @@ """testing stancsv parsing""" +import io import os from pathlib import Path from test import without_import @@ -15,7 +16,7 @@ DATAFILES_PATH = os.path.join(HERE, 'data') -def test_csv_bytes_to_numpy_no_header(): +def test_csv_bytes_to_numpy(): lines = [ b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n", b"-6.81411,0.983499,0.787025,1,1,0,6.8147,0.20649\n", @@ -31,12 +32,12 @@ def test_csv_bytes_to_numpy_no_header(): ], dtype=np.float64, ) - arr_out = stancsv.csv_bytes_list_to_numpy(lines, includes_header=False) + arr_out = stancsv.csv_bytes_list_to_numpy(lines) assert np.array_equal(arr_out, expected) assert arr_out[0].dtype == np.float64 -def test_csv_bytes_to_numpy_no_header_no_polars(): +def test_csv_bytes_to_numpy_no_polars(): lines = [ b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n", b"-6.81411,0.983499,0.787025,1,1,0,6.8147,0.20649\n", @@ -53,35 +54,11 @@ def test_csv_bytes_to_numpy_no_header_no_polars(): dtype=np.float64, ) with without_import("polars", cmdstanpy.utils.stancsv): - arr_out = stancsv.csv_bytes_list_to_numpy(lines, includes_header=False) + arr_out = stancsv.csv_bytes_list_to_numpy(lines) assert np.array_equal(arr_out, expected) assert arr_out[0].dtype == np.float64 -def test_csv_bytes_to_numpy_with_header(): - lines = [ - ( - b"lp__,accept_stat__,stepsize__,treedepth__," - b"n_leapfrog__,divergent__,energy__,theta\n" - ), - b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n", - b"-6.81411,0.983499,0.787025,1,1,0,6.8147,0.20649\n", - b"-6.85511,0.994945,0.787025,2,3,0,6.85536,0.310589\n", - b"-6.85511,0.812189,0.787025,1,1,0,7.16517,0.310589\n", - ] - expected = np.array( - [ - [-6.76206, 1, 0.787025, 1, 1, 0, 6.81411, 0.229458], - [-6.81411, 0.983499, 0.787025, 1, 1, 0, 6.8147, 0.20649], - [-6.85511, 0.994945, 0.787025, 2, 3, 0, 6.85536, 0.310589], - [-6.85511, 0.812189, 0.787025, 1, 1, 0, 7.16517, 0.310589], - ], - dtype=np.float64, - ) - arr_out = stancsv.csv_bytes_list_to_numpy(lines, includes_header=True) - assert np.array_equal(arr_out, expected) - - def test_csv_bytes_to_numpy_single_element(): lines = [ b"-6.76206\n", @@ -92,7 +69,7 @@ def test_csv_bytes_to_numpy_single_element(): ], dtype=np.float64, ) - arr_out = stancsv.csv_bytes_list_to_numpy(lines, includes_header=False) + arr_out = stancsv.csv_bytes_list_to_numpy(lines) assert np.array_equal(arr_out, expected) @@ -107,72 +84,27 @@ def test_csv_bytes_to_numpy_single_element_no_polars(): dtype=np.float64, ) with without_import("polars", cmdstanpy.utils.stancsv): - arr_out = stancsv.csv_bytes_list_to_numpy(lines, includes_header=False) - assert np.array_equal(arr_out, expected) - - -def test_csv_bytes_to_numpy_with_header_no_polars(): - lines = [ - ( - b"lp__,accept_stat__,stepsize__,treedepth__," - b"n_leapfrog__,divergent__,energy__,theta\n" - ), - b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n", - b"-6.81411,0.983499,0.787025,1,1,0,6.8147,0.20649\n", - b"-6.85511,0.994945,0.787025,2,3,0,6.85536,0.310589\n", - b"-6.85511,0.812189,0.787025,1,1,0,7.16517,0.310589\n", - ] - expected = np.array( - [ - [-6.76206, 1, 0.787025, 1, 1, 0, 6.81411, 0.229458], - [-6.81411, 0.983499, 0.787025, 1, 1, 0, 6.8147, 0.20649], - [-6.85511, 0.994945, 0.787025, 2, 3, 0, 6.85536, 0.310589], - [-6.85511, 0.812189, 0.787025, 1, 1, 0, 7.16517, 0.310589], - ], - dtype=np.float64, - ) - with without_import("polars", cmdstanpy.utils.stancsv): - arr_out = stancsv.csv_bytes_list_to_numpy(lines, includes_header=True) + arr_out = stancsv.csv_bytes_list_to_numpy(lines) assert np.array_equal(arr_out, expected) def test_csv_bytes_empty(): lines = [] arr = stancsv.csv_bytes_list_to_numpy(lines) - assert np.array_equal(arr, np.empty((0,))) - - -def test_csv_bytes_to_numpy_header_no_draws(): - lines = [ - ( - b"lp__,accept_stat__,stepsize__,treedepth__," - b"n_leapfrog__,divergent__,energy__,theta\n" - ), - ] - arr = stancsv.csv_bytes_list_to_numpy(lines) - assert arr.shape == (0, 8) + assert np.array_equal(arr, np.empty((0, 0))) -def test_csv_bytes_to_numpy_header_no_draws_no_polars(): - lines = [ - ( - b"lp__,accept_stat__,stepsize__,treedepth__," - b"n_leapfrog__,divergent__,energy__,theta\n" - ), - ] - with without_import("polars", cmdstanpy.utils.stancsv): - arr = stancsv.csv_bytes_list_to_numpy(lines) - assert arr.shape == (0, 8) - - -def test_parse_comments_and_draws(): - lines: List[bytes] = [b"# 1\n", b"2\n", b"3\n", b"# 4\n"] - comment_lines, draws_lines = stancsv.parse_stan_csv_comments_and_draws( - iter(lines) - ) +def test_parse_comments_header_and_draws(): + lines: List[bytes] = [b"# 1\n", b"a\n", b"3\n", b"# 4\n"] + ( + comment_lines, + header, + draws_lines, + ) = stancsv.parse_comments_header_and_draws(iter(lines)) assert comment_lines == [b"# 1\n", b"# 4\n"] - assert draws_lines == [b"2\n", b"3\n"] + assert header == "a" + assert draws_lines == [b"3\n"] def test_parsing_adaptation_lines(): @@ -181,6 +113,7 @@ def test_parsing_adaptation_lines(): b"# Step size = 0.787025\n", b"# Diagonal elements of inverse mass matrix:\n", b"# 1\n", + b"# Elapsed Time\n", ] step_size, mass_matrix = stancsv.parse_hmc_adaptation_lines(lines) assert step_size == 0.787025 @@ -225,16 +158,12 @@ def test_parsing_adaptation_lines_dense(): assert np.array_equal(mass_matrix, expected) -def test_parsing_adaptation_lines_missing_step_size(): +def test_parsing_adaptation_lines_missing_everything(): lines = [ b"# Adaptation terminated\n", b"# Elements of inverse mass matrix:\n", - b"# 2.84091, 0.230843, 0.0509365\n", - b"# 0.230843, 3.92459, 0.126989\n", - b"# 0.0509365, 0.126989, 3.82718\n", ] - with pytest.raises(ValueError): - stancsv.parse_hmc_adaptation_lines(lines) + assert stancsv.parse_hmc_adaptation_lines(lines) == (None, None) def test_parsing_adaptation_lines_no_free_params(): @@ -254,13 +183,9 @@ def test_csv_polars_and_numpy_equiv(): b"-6.85511,0.994945,0.787025,2,3,0,6.85536,0.310589\n", b"-6.85511,0.812189,0.787025,1,1,0,7.16517,0.310589\n", ] - arr_out_polars = stancsv.csv_bytes_list_to_numpy( - lines, includes_header=False - ) + arr_out_polars = stancsv.csv_bytes_list_to_numpy(lines) with without_import("polars", cmdstanpy.utils.stancsv): - arr_out_numpy = stancsv.csv_bytes_list_to_numpy( - lines, includes_header=False - ) + arr_out_numpy = stancsv.csv_bytes_list_to_numpy(lines) assert np.array_equal(arr_out_polars, arr_out_numpy) @@ -268,13 +193,9 @@ def test_csv_polars_and_numpy_equiv_one_line(): lines = [ b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n", ] - arr_out_polars = stancsv.csv_bytes_list_to_numpy( - lines, includes_header=False - ) + arr_out_polars = stancsv.csv_bytes_list_to_numpy(lines) with without_import("polars", cmdstanpy.utils.stancsv): - arr_out_numpy = stancsv.csv_bytes_list_to_numpy( - lines, includes_header=False - ) + arr_out_numpy = stancsv.csv_bytes_list_to_numpy(lines) assert np.array_equal(arr_out_polars, arr_out_numpy) @@ -282,31 +203,473 @@ def test_csv_polars_and_numpy_equiv_one_element(): lines = [ b"-6.76206\n", ] - arr_out_polars = stancsv.csv_bytes_list_to_numpy( - lines, includes_header=False - ) + arr_out_polars = stancsv.csv_bytes_list_to_numpy(lines) with without_import("polars", cmdstanpy.utils.stancsv): - arr_out_numpy = stancsv.csv_bytes_list_to_numpy( - lines, includes_header=False - ) + arr_out_numpy = stancsv.csv_bytes_list_to_numpy(lines) assert np.array_equal(arr_out_polars, arr_out_numpy) def test_parse_stan_csv_from_file(): csv_path = os.path.join(DATAFILES_PATH, "bernoulli_output_1.csv") - comment_lines, draws_lines = stancsv.parse_stan_csv_comments_and_draws( - csv_path - ) + ( + comment_lines, + header, + draws_lines, + ) = stancsv.parse_comments_header_and_draws(csv_path) assert all(ln.startswith(b"#") for ln in comment_lines) + assert header is not None and not header.startswith("#") assert all(not ln.startswith(b"#") for ln in draws_lines) ( comment_lines_path, + header_path, draws_lines_path, - ) = stancsv.parse_stan_csv_comments_and_draws(Path(csv_path)) - assert all(ln.startswith(b"#") for ln in comment_lines) - assert all(not ln.startswith(b"#") for ln in draws_lines) + ) = stancsv.parse_comments_header_and_draws(Path(csv_path)) + assert all(ln.startswith(b"#") for ln in comment_lines_path) + assert header_path is not None and not header.startswith("#") + assert all(not ln.startswith(b"#") for ln in draws_lines_path) assert comment_lines == comment_lines_path + assert header == header_path assert draws_lines == draws_lines_path + + +def test_config_parsing(): + csv_path = os.path.join(DATAFILES_PATH, "bernoulli_output_1.csv") + + comment_lines, *_ = stancsv.parse_comments_header_and_draws(csv_path) + config = stancsv.parse_config(comment_lines) + + expected = { + 'stan_version_major': 2, + 'stan_version_minor': 19, + 'stan_version_patch': 0, + 'model': 'bernoulli_model', + 'method': 'sample', + 'num_samples': 10, + 'num_warmup': 100, + 'save_warmup': 0, + 'thin': 1, + 'engaged': 1, + 'gamma': 0.05, + 'delta': 0.8, + 'kappa': 0.75, + 't0': 10, + 'init_buffer': 75, + 'term_buffer': 50, + 'window': 25, + 'algorithm': 'hmc', + 'engine': 'nuts', + 'max_depth': 10, + 'metric': 'diag_e', + 'metric_file': '', + 'stepsize': 1, + 'stepsize_jitter': 0, + 'id': 1, + 'data_file': 'examples/bernoulli/bernoulli.data.json', + 'init': 2, + 'seed': 123456, + 'diagnostic_file': '', + 'refresh': 100, + 'Step size': 0.787025, + } + + assert config == expected + + +def test_config_parsing_data_transforms(): + comments = [ + b"# bool_t = true\n", + b"# bool_f = false\n", + b"# float = 1.5\n", + b"# int = 1\n", + ] + expected = {"bool_t": 1, "bool_f": 0, "float": 1.5, "int": 1} + assert stancsv.parse_config(comments) == expected + + +def test_column_filter_basic(): + data = [b"1,2,3\n", b"4,5,6\n"] + indexes = [0, 2] + expected = [b"1,3\n", b"4,6\n"] + assert stancsv.filter_csv_bytes_by_columns(data, indexes) == expected + + +def test_column_filter_empty_input(): + assert not stancsv.filter_csv_bytes_by_columns([], [0]) + + +def test_column_filter_empty_indexes(): + data = [b"1,2,3\n", b"4,5,6\n"] + assert stancsv.filter_csv_bytes_by_columns(data, []) == [b"\n", b"\n"] + + +def test_column_filter_single_column(): + data = [b"a,b,c\n", b"d,e,f\n"] + assert stancsv.filter_csv_bytes_by_columns(data, [1]) == [b"b\n", b"e\n"] + + +def test_column_filter_non_consecutive_indexes(): + data = [b"9,8,7,6\n", b"5,4,3,2\n"] + assert stancsv.filter_csv_bytes_by_columns(data, [2, 0]) == [ + b"7,9\n", + b"3,5\n", + ] + + +def test_parse_header(): + header = ( + "lp__,accept_stat__,stepsize__,treedepth__" + ",n_leapfrog__,divergent__,energy__,theta.1" + ) + parsed = stancsv.parse_header(header) + expected = ( + "lp__", + "accept_stat__", + "stepsize__", + "treedepth__", + "n_leapfrog__", + "divergent__", + "energy__", + "theta[1]", + ) + assert parsed == expected + + +def test_extract_config_and_header_info(): + comments = [b"# stan_version_major = 2\n"] + header = "lp__,theta.1" + out = stancsv.construct_config_header_dict(comments, header) + assert out["stan_version_major"] == 2 + assert out["raw_header"] == "lp__,theta.1" + assert out["column_names"] == ("lp__", "theta[1]") + + +def test_parse_variational_eta(): + csv_path = os.path.join(DATAFILES_PATH, "variational", "eta_big_output.csv") + comments, *_ = stancsv.parse_comments_header_and_draws(csv_path) + eta = stancsv.parse_variational_eta(comments) + assert eta == 100.0 + + +def test_parse_variational_eta_no_block(): + comments = [ + b"# stanc_version = stanc3 v2.28.0\n", + b"# stancflags = \n", + b"lp__,log_p__,log_g__,mu.1,mu.2\n", + b"0,0,0,311.545,532.801\n", + b"0,-186118,-4.74553,311.545,353.503\n", + b"0,-184982,-2.75303,311.545,587.377\n", + ] + + with pytest.raises(ValueError): + stancsv.parse_variational_eta(comments) + + +def test_max_treedepth_and_divergence_counts(): + header = ( + "lp__,accept_stat__,stepsize__,treedepth__," + "n_leapfrog__,divergent__,energy__,theta\n" + ) + draws = [ + b"-4.78686,0.986298,1.09169,1,3,0,5.29492,0.550024\n", + b"-5.07942,0.676947,1.09169,10,3,0,6.44279,0.709113\n", + b"-5.04922,1,1.09169,1,1,0,5.14176,0.702445\n", + b"-5.09338,0.996111,1.09169,10,3,1,5.16083,0.712059\n", + b"-4.78903,0.989798,1.09169,1,3,0,5.08116,0.546685\n", + b"-5.36502,0.854345,1.09169,1,3,0,5.39311,0.369686\n", + b"-5.13605,0.937837,1.09169,1,3,0,5.95811,0.720607\n", + b"-4.80646,1,1.09169,2,3,0,5.0962,0.528418\n", + ] + out = stancsv.extract_max_treedepth_and_divergence_counts( + header, draws, 10, 0 + ) + assert out == (2, 1) + + +def test_max_treedepth_and_divergence_counts_warmup_draws(): + header = ( + "lp__,accept_stat__,stepsize__,treedepth__," + "n_leapfrog__,divergent__,energy__,theta\n" + ) + draws = [ + b"-4.78686,0.986298,1.09169,1,3,0,5.29492,0.550024\n", + b"-5.07942,0.676947,1.09169,10,3,0,6.44279,0.709113\n", + b"-5.04922,1,1.09169,1,1,0,5.14176,0.702445\n", + b"-5.09338,0.996111,1.09169,10,3,1,5.16083,0.712059\n", + b"-4.78903,0.989798,1.09169,1,3,0,5.08116,0.546685\n", + b"-5.36502,0.854345,1.09169,1,3,0,5.39311,0.369686\n", + b"-5.13605,0.937837,1.09169,1,3,0,5.95811,0.720607\n", + b"-4.80646,1,1.09169,2,3,0,5.0962,0.528418\n", + ] + out = stancsv.extract_max_treedepth_and_divergence_counts( + header, draws, 10, 2 + ) + assert out == (1, 1) + + +def test_max_treedepth_and_divergence_counts_no_draws(): + header = ( + "lp__,accept_stat__,stepsize__,treedepth__," + "n_leapfrog__,divergent__,energy__,theta\n" + ) + draws = [] + out = stancsv.extract_max_treedepth_and_divergence_counts( + header, draws, 10, 0 + ) + assert out == (0, 0) + + +def test_max_treedepth_and_divergence_invalid(): + header = "lp__,accept_stat__,stepsize__,n_leapfrog__,energy__,theta\n" + draws = [ + b"-4.78686,0.986298,1.09169,3,5.29492,0.550024\n", + ] + assert stancsv.extract_max_treedepth_and_divergence_counts( + header, draws, 10, 0 + ) == (0, 0) + + +def test_sneaky_fixed_param_check(): + sneaky_header = "lp__,accept_stat__,N,y_sim.1" + normal_header = ( + "lp__,accept_stat__,stepsize__,treedepth__," + "n_leapfrog__,divergent__,energy__,theta" + ) + + assert stancsv.is_sneaky_fixed_param(sneaky_header) + assert not stancsv.is_sneaky_fixed_param(normal_header) + + +def test_warmup_sampling_draw_counts(): + csv_path = os.path.join(DATAFILES_PATH, "bernoulli_output_1.csv") + assert stancsv.count_warmup_and_sampling_draws(csv_path) == (0, 10) + + +def test_warmup_sampling_draw_counts_with_warmup(): + lines = [ + b"# algorithm = hmc (Default)\n", + ( + b"lp__,accept_stat__,stepsize__,treedepth__," + b"n_leapfrog__,divergent__,energy__,theta\n" + ), + b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n", + b"# Adaptation terminated\n", + b"# Step size = 0.787025\n", + b"# Diagonal elements of inverse mass matrix:\n", + b"# 1\n", + b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n", + b"# \n", + b"# Elapsed Time: 0.001332 seconds (Warm-up)\n", + ] + fio = io.BytesIO(b"".join(lines)) + assert stancsv.count_warmup_and_sampling_draws(fio) == (1, 1) + + +def test_warmup_sampling_draw_counts_fixed_param(): + lines = [ + b"# algorithm = fixed_param\n", + ( + b"lp__,accept_stat__,stepsize__,treedepth__," + b"n_leapfrog__,divergent__,energy__,theta\n" + ), + b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n", + b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n", + b"# \n", + b"# Elapsed Time: 0.001332 seconds (Warm-up)\n", + ] + fio = io.BytesIO(b"".join(lines)) + assert stancsv.count_warmup_and_sampling_draws(fio) == (0, 2) + + +def test_warmup_sampling_draw_counts_no_draws(): + lines = [ + b"# algorithm = fixed_param\n", + ( + b"lp__,accept_stat__,stepsize__,treedepth__," + b"n_leapfrog__,divergent__,energy__,theta\n" + ), + b"# Elapsed Time: 0.001332 seconds (Warm-up)\n", + b"# 0.001332 seconds (Sampling)\n", + ] + fio = io.BytesIO(b"".join(lines)) + assert stancsv.count_warmup_and_sampling_draws(fio) == (0, 0) + + +def test_warmup_sampling_draw_counts_invalid(): + lines = [ + b"# algorithm = fixed_param\n", + ] + fio = io.BytesIO(b"".join(lines)) + with pytest.raises(ValueError): + stancsv.count_warmup_and_sampling_draws(fio) + + +def test_inconsistent_draws_shape(): + header = "a,b" + draws = [b"0,1,2\n"] + with pytest.raises(ValueError): + stancsv.raise_on_inconsistent_draws_shape(header, draws) + + +def test_inconsistent_draws_shape_empty(): + draws = [] + stancsv.raise_on_inconsistent_draws_shape("", draws) + + +def test_invalid_adaptation_block_good(): + csv_path = os.path.join(DATAFILES_PATH, "bernoulli_output_1.csv") + comments, *_ = stancsv.parse_comments_header_and_draws(csv_path) + stancsv.raise_on_invalid_adaptation_block(comments) + + +def test_invalid_adaptation_block_missing(): + lines = [ + b"# metric = diag_e (Default)\n", + ( + b"lp__,accept_stat__,stepsize__,treedepth__," + b"n_leapfrog__,divergent__,energy__,theta\n" + ), + b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n", + b"# \n", + b"# Elapsed Time: 0.001332 seconds (Warm-up)\n", + ] + with pytest.raises(ValueError, match="expecting metric"): + stancsv.raise_on_invalid_adaptation_block(lines) + + +def test_invalid_adaptation_block_no_metric(): + lines = [ + ( + b"lp__,accept_stat__,stepsize__,treedepth__," + b"n_leapfrog__,divergent__,energy__,theta\n" + ), + b"# Adaptation terminated\n", + b"# Step size = 0.787025\n", + b"# Diagonal elements of inverse mass matrix:\n", + b"# 1\n", + ] + with pytest.raises(ValueError, match="No reported metric"): + stancsv.raise_on_invalid_adaptation_block(lines) + + +def test_invalid_adaptation_block_invalid_step_size(): + lines = [ + b"# metric = diag_e (Default)\n", + ( + b"lp__,accept_stat__,stepsize__,treedepth__," + b"n_leapfrog__,divergent__,energy__,theta\n" + ), + b"# Adaptation terminated\n", + b"# Step size = bad\n", + b"# Diagonal elements of inverse mass matrix:\n", + b"# 1\n", + ] + with pytest.raises(ValueError, match="invalid step size"): + stancsv.raise_on_invalid_adaptation_block(lines) + + +def test_invalid_adaptation_block_mismatched_structure(): + lines = [ + b"# metric = diag_e (Default)\n", + ( + b"lp__,accept_stat__,stepsize__,treedepth__," + b"n_leapfrog__,divergent__,energy__,theta\n" + ), + b"# Adaptation terminated\n", + b"# Step size = 0.787025\n", + b"# Elements of inverse mass matrix:\n", + b"# 1\n", + ] + with pytest.raises(ValueError, match="invalid or missing"): + stancsv.raise_on_invalid_adaptation_block(lines) + + +def test_invalid_adaptation_block_missing_step_size(): + lines = [ + b"# metric = diag_e (Default)\n", + ( + b"lp__,accept_stat__,stepsize__,treedepth__," + b"n_leapfrog__,divergent__,energy__,theta\n" + ), + b"# Adaptation terminated\n", + b"# Diagonal elements of inverse mass matrix:\n", + b"# 1\n", + ] + with pytest.raises(ValueError, match="expecting step size"): + stancsv.raise_on_invalid_adaptation_block(lines) + + +def test_invalid_adaptation_block_unit_e(): + lines = [ + b"# metric = unit_e\n", + ( + b"lp__,accept_stat__,stepsize__,treedepth__," + b"n_leapfrog__,divergent__,energy__,theta\n" + ), + b"# Adaptation terminated\n", + b"# Step size = 1.77497\n", + b"# No free parameters for unit metric\n", + ] + stancsv.raise_on_invalid_adaptation_block(lines) + + +def test_invalid_adaptation_block_dense_e_valid(): + lines = [ + b"# metric = dense_e\n", + ( + b"lp__,accept_stat__,stepsize__,treedepth__," + b"n_leapfrog__,divergent__,energy__,theta.1,theta.2,theta.3\n" + ), + b"# Adaptation terminated\n", + b"# Step size = 0.775147\n", + b"# Elements of inverse mass matrix:\n", + b"# 2.84091, 0.230843, 0.0509365\n", + b"# 0.230843, 3.92459, 0.126989\n", + b"# 0.0509365, 0.126989, 3.82718\n", + ] + stancsv.raise_on_invalid_adaptation_block(lines) + + +def test_invalid_adaptation_block_dense_e_invalid(): + lines = [ + b"# metric = dense_e\n", + ( + b"lp__,accept_stat__,stepsize__,treedepth__," + b"n_leapfrog__,divergent__,energy__,theta.1,theta.2,theta.3\n" + ), + b"# Adaptation terminated\n", + b"# Step size = 0.775147\n", + b"# Elements of inverse mass matrix:\n", + b"# 2.84091, 0.230843, 0.0509365\n", + b"# 2.84091, 0.230843\n", + b"# 0.230843, 3.92459\n", + ] + with pytest.raises(ValueError, match="invalid or missing"): + stancsv.raise_on_invalid_adaptation_block(lines) + + +def test_parsing_timing_lines(): + lines = [ + b"# \n", + b"# Elapsed Time: 0.001332 seconds (Warm-up)\n", + b"# 0.000249 seconds (Sampling)\n", + b"# 0.001581 seconds (Total)\n", + b"# \n", + ] + out = stancsv.parse_timing_lines(lines) + + assert len(out) == 3 + assert out['Warm-up'] == 0.001332 + assert out['Sampling'] == 0.000249 + assert out['Total'] == 0.001581 + + +def test_munge_varname(): + name1 = "a" + name2 = "a:1" + name3 = "a:1.2" + assert stancsv.munge_varname(name1) == "a" + assert stancsv.munge_varname(name2) == "a.1" + assert stancsv.munge_varname(name3) == "a.1[2]" diff --git a/test/test_utils.py b/test/test_utils.py index 27e5db53..5a4d6167 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -294,7 +294,6 @@ def test_check_sampler_csv_1() -> None: csv_good = os.path.join(DATAFILES_PATH, 'bernoulli_output_1.csv') dict = check_sampler_csv( path=csv_good, - is_fixed_param=False, iter_warmup=100, iter_sampling=10, thin=1, @@ -384,7 +383,6 @@ def test_check_sampler_csv_thin() -> None: csv_file = bern_fit.runset.csv_files[0] dict = check_sampler_csv( path=csv_file, - is_fixed_param=False, iter_sampling=490, iter_warmup=490, thin=7, @@ -399,7 +397,6 @@ def test_check_sampler_csv_thin() -> None: with raises_nested(ValueError, 'config error'): check_sampler_csv( path=csv_file, - is_fixed_param=False, iter_sampling=490, iter_warmup=490, thin=9, @@ -407,7 +404,6 @@ def test_check_sampler_csv_thin() -> None: with raises_nested(ValueError, 'expected 490 draws, found 70'): check_sampler_csv( path=csv_file, - is_fixed_param=False, iter_sampling=490, iter_warmup=490, ) @@ -702,57 +698,3 @@ def test_munge_varnames() -> None: var = 'y.2.3:1.2:5:6' assert stancsv.munge_varname(var) == 'y[2,3].1[2].5.6' - - -def test_scan_time_normal() -> None: - csv_content = ( - "# Elapsed Time: 0.005 seconds (Warm-up)\n" - "# 0 seconds (Sampling)\n" - "# 0.005 seconds (Total)\n" - ) - fd = io.StringIO(csv_content) - config_dict = {} - start_line = 0 - final_line = stancsv.scan_time(fd, config_dict, start_line) - assert final_line == 3 - expected = {'warmup': 0.005, 'sampling': 0.0, 'total': 0.005} - assert config_dict.get('time') == expected - - -def test_scan_time_no_timing() -> None: - csv_content = ( - "# merrily we roll along\n" - "# roll along\n" - "# very merrily we roll along\n" - ) - fd = io.StringIO(csv_content) - config_dict = {} - start_line = 0 - with pytest.raises(ValueError, match="Invalid time"): - stancsv.scan_time(fd, config_dict, start_line) - - -def test_scan_time_invalid_value() -> None: - csv_content = ( - "# Elapsed Time: abc seconds (Warm-up)\n" - "# 0.200 seconds (Sampling)\n" - "# 0.300 seconds (Total)\n" - ) - fd = io.StringIO(csv_content) - config_dict = {} - start_line = 0 - with pytest.raises(ValueError, match="Invalid time"): - stancsv.scan_time(fd, config_dict, start_line) - - -def test_scan_time_invalid_string() -> None: - csv_content = ( - "# Elapsed Time: 0.22 seconds (foo)\n" - "# 0.200 seconds (Sampling)\n" - "# 0.300 seconds (Total)\n" - ) - fd = io.StringIO(csv_content) - config_dict = {} - start_line = 0 - with pytest.raises(ValueError, match="Invalid time"): - stancsv.scan_time(fd, config_dict, start_line)