diff --git a/cmdstanpy/cmdstan_args.py b/cmdstanpy/cmdstan_args.py index 80441bae..77450d70 100644 --- a/cmdstanpy/cmdstan_args.py +++ b/cmdstanpy/cmdstan_args.py @@ -10,15 +10,7 @@ import numpy as np from numpy.random import default_rng -from cmdstanpy import _TMPDIR -from cmdstanpy.utils import ( - cmdstan_path, - cmdstan_version_before, - create_named_text_file, - get_logger, - read_metric, - write_stan_json, -) +from cmdstanpy.utils import cmdstan_path, cmdstan_version_before, get_logger OptionalPath = Union[str, os.PathLike, None] @@ -65,9 +57,8 @@ def __init__( save_warmup: bool = False, thin: Optional[int] = None, max_treedepth: Optional[int] = None, - metric: Union[ - str, dict[str, Any], list[str], list[dict[str, Any]], None - ] = None, + metric_type: Optional[str] = None, + metric_file: Union[str, list[str], None] = None, step_size: Union[float, list[float], None] = None, adapt_engaged: bool = True, adapt_delta: Optional[float] = None, @@ -83,9 +74,8 @@ def __init__( self.save_warmup = save_warmup self.thin = thin self.max_treedepth = max_treedepth - self.metric = metric - self.metric_type: Optional[str] = None - self.metric_file: Union[str, list[str], None] = None + self.metric_type: Optional[str] = metric_type + self.metric_file: Union[str, list[str], None] = metric_file self.step_size = step_size self.adapt_engaged = adapt_engaged self.adapt_delta = adapt_delta @@ -178,124 +168,15 @@ def validate(self, chains: Optional[int]) -> None: 'Argument "step_size" must be > 0, ' 'chain {}, found {}.'.format(i + 1, step_size) ) - if self.metric is not None: - if isinstance(self.metric, str): - if self.metric in ['diag', 'diag_e']: - self.metric_type = 'diag_e' - elif self.metric in ['dense', 'dense_e']: - self.metric_type = 'dense_e' - elif self.metric in ['unit', 'unit_e']: - self.metric_type = 'unit_e' - else: - if not os.path.exists(self.metric): - raise ValueError('no such file {}'.format(self.metric)) - dims = read_metric(self.metric) - if len(dims) == 1: - self.metric_type = 'diag_e' - else: - self.metric_type = 'dense_e' - self.metric_file = self.metric - elif isinstance(self.metric, dict): - if 'inv_metric' not in self.metric: - raise ValueError( - 'Entry "inv_metric" not found in metric dict.' - ) - dims = list(np.asarray(self.metric['inv_metric']).shape) - if len(dims) == 1: - self.metric_type = 'diag_e' - else: - self.metric_type = 'dense_e' - dict_file = create_named_text_file( - dir=_TMPDIR, prefix="metric", suffix=".json" - ) - write_stan_json(dict_file, self.metric) - self.metric_file = dict_file - elif isinstance(self.metric, (list, tuple)): - if len(self.metric) != chains: - raise ValueError( - 'Number of metric files must match number of chains,' - ' found {} metric files for {} chains.'.format( - len(self.metric), chains - ) - ) - if all(isinstance(elem, dict) for elem in self.metric): - metric_files: list[str] = [] - for i, metric in enumerate(self.metric): - metric_dict: dict[str, Any] = metric # type: ignore - if 'inv_metric' not in metric_dict: - raise ValueError( - 'Entry "inv_metric" not found in metric dict ' - 'for chain {}.'.format(i + 1) - ) - if i == 0: - dims = list( - np.asarray(metric_dict['inv_metric']).shape - ) - else: - dims2 = list( - np.asarray(metric_dict['inv_metric']).shape - ) - if dims != dims2: - raise ValueError( - 'Found inconsistent "inv_metric" entry ' - 'for chain {}: entry has dims ' - '{}, expected {}.'.format( - i + 1, dims, dims2 - ) - ) - dict_file = create_named_text_file( - dir=_TMPDIR, prefix="metric", suffix=".json" - ) - write_stan_json(dict_file, metric_dict) - metric_files.append(dict_file) - if len(dims) == 1: - self.metric_type = 'diag_e' - else: - self.metric_type = 'dense_e' - self.metric_file = metric_files - elif all(isinstance(elem, str) for elem in self.metric): - metric_files = [] - for i, metric in enumerate(self.metric): - assert isinstance(metric, str) # typecheck - if not os.path.exists(metric): - raise ValueError('no such file {}'.format(metric)) - if i == 0: - dims = read_metric(metric) - else: - dims2 = read_metric(metric) - if len(dims) != len(dims2): - raise ValueError( - 'Metrics files {}, {},' - ' inconsistent metrics'.format( - self.metric[0], metric - ) - ) - if dims != dims2: - raise ValueError( - 'Metrics files {}, {},' - ' inconsistent metrics'.format( - self.metric[0], metric - ) - ) - metric_files.append(metric) - if len(dims) == 1: - self.metric_type = 'diag_e' - else: - self.metric_type = 'dense_e' - self.metric_file = metric_files - else: - raise ValueError( - 'Argument "metric" must be a list of pathnames or ' - 'Python dicts, found list of {}.'.format( - type(self.metric[0]) - ) - ) - else: + if self.metric_type is not None: + if self.metric_type in ['diag', 'dense', 'unit']: + self.metric_type += '_e' + if self.metric_type not in ['diag_e', 'dense_e', 'unit_e']: raise ValueError( - 'Invalid metric specified, not a recognized metric type, ' - 'must be either a metric type name, a filepath, dict, ' - 'or list of per-chain filepaths or dicts. Found ' - 'an object of type {}.'.format(type(self.metric)) + 'Argument "metric" must be one of [diag, dense, unit,' + ' diag_e, dense_e, unit_e], found {}.'.format( + self.metric_type + ) ) if self.adapt_delta is not None: @@ -332,7 +213,8 @@ def validate(self, chains: Optional[int]) -> None: if self.fixed_param and ( self.max_treedepth is not None - or self.metric is not None + or self.metric_type is not None + or self.metric_file is not None or self.step_size is not None or not ( self.adapt_delta is None @@ -371,7 +253,7 @@ def compose(self, idx: int, cmd: list[str]) -> list[str]: cmd.append(f'stepsize={self.step_size}') else: cmd.append(f'stepsize={self.step_size[idx]}') - if self.metric is not None: + if self.metric_type is not None: cmd.append(f'metric={self.metric_type}') if self.metric_file is not None: if not isinstance(self.metric_file, list): diff --git a/cmdstanpy/model.py b/cmdstanpy/model.py index 893c352f..3335d216 100644 --- a/cmdstanpy/model.py +++ b/cmdstanpy/model.py @@ -24,6 +24,7 @@ Union, ) +import numpy as np import pandas as pd from tqdm.auto import tqdm @@ -55,7 +56,12 @@ get_logger, returncode_msg, ) -from cmdstanpy.utils.filesystem import temp_inits, temp_single_json +from cmdstanpy.utils.filesystem import ( + temp_inits, + temp_metrics, + temp_single_json, +) +from cmdstanpy.utils.stancsv import try_deduce_metric_type from . import progress as progbar @@ -697,6 +703,13 @@ def sample( timeout: Optional[float] = None, *, force_one_process_per_chain: Optional[bool] = None, + inv_metric: Union[ + str, + np.ndarray, + Mapping[str, Any], + list[Union[str, np.ndarray, Mapping[str, Any]]], + None, + ] = None, ) -> CmdStanMCMC: """ Run or more chains of the NUTS-HMC sampler to produce a set of draws @@ -785,29 +798,25 @@ def sample( :param max_treedepth: Maximum depth of trees evaluated by NUTS sampler per iteration. - :param metric: Specification of the mass matrix, either as a - vector consisting of the diagonal elements of the covariance - matrix ('diag' or 'diag_e') or the full covariance matrix - ('dense' or 'dense_e'). - - If the value of the metric argument is a string other than - 'diag', 'diag_e', 'dense', or 'dense_e', it must be - a valid filepath to a JSON or Rdump file which contains an entry - 'inv_metric' whose value is either the diagonal vector or - the full covariance matrix. - - If the value of the metric argument is a list of paths, its - length must match the number of chains and all paths must be - unique. - - If the value of the metric argument is a Python dict object, it - must contain an entry 'inv_metric' which specifies either the - diagnoal or dense matrix. - - If the value of the metric argument is a list of Python dicts, - its length must match the number of chains and all dicts must - containan entry 'inv_metric' and all 'inv_metric' entries must - have the same shape. + :param metric: Specify the type of the inverse mass matrix. Options are + 'diag' or 'diag_e' for diagonal matrix, 'dense' or 'dense_e' + for a dense matrix, or 'unit_e' an identity mass matrix. To provide + an initial value for the inverse mass matrix, use the ``inv_metric`` + argument. + + :param inv_metric: Provide an initial value for the inverse + mass matrix. + + Valid options include: + - a string, which must be a valid filepath to a JSON or + Rdump file which contains an entry 'inv_metric' whose value + is either a diagonal vector or dense matrix. + - a numpy array containing either the diagonal vector or dense + matrix. + - a dictionary containing an entry 'inv_metric' whose value + is either a diagonal vector or dense matrix. + - a list of any of the above, of length num_chains, with + the same shape of metric in each entry. :param step_size: Initial step size for HMC sampler. The value is either a single number or a list of numbers which will be used @@ -1001,35 +1010,79 @@ def sample( 'Chain_id must be a non-negative integer value,' ' found {}.'.format(chain_id) ) + if metric is not None and metric not in ( + 'diag', + 'dense', + 'unit_e', + 'diag_e', + 'dense_e', + ): + get_logger().warning( + "Providing anything other than metric type for" + " 'metric' is deprecated and will be removed" + " in the next major release." + " Please provide such information via" + " 'inv_metric' argument." + ) + if inv_metric is not None: + raise ValueError( + "Cannot provide both (deprecated) non-metric-type 'metric'" + " argument and 'inv_metric' argument." + ) + inv_metric = metric # type: ignore # for backwards compatibility + metric = None - sampler_args = SamplerArgs( - num_chains=1 if one_process_per_chain else chains, - iter_warmup=iter_warmup, - iter_sampling=iter_sampling, - save_warmup=save_warmup, - thin=thin, - max_treedepth=max_treedepth, - metric=metric, - step_size=step_size, - adapt_engaged=adapt_engaged, - adapt_delta=adapt_delta, - adapt_init_phase=adapt_init_phase, - adapt_metric_window=adapt_metric_window, - adapt_step_size=adapt_step_size, - fixed_param=fixed_param, - ) + if metric is None and inv_metric is not None: + metric = try_deduce_metric_type(inv_metric) + + if isinstance(inv_metric, list): + if not len(inv_metric) == chains: + raise ValueError( + 'Number of metric files must match number of chains,' + ' found {} metric files for {} chains.'.format( + len(inv_metric), chains + ) + ) with ( temp_single_json(data) as _data, temp_inits(inits, id=chain_ids[0]) as _inits, + temp_metrics(inv_metric, id=chain_ids[0]) as _inv_metric, ): cmdstan_inits: Union[str, list[str], int, float, None] + cmdstan_metrics: Union[str, list[str], None] + if one_process_per_chain and isinstance(inits, list): # legacy cmdstan_inits = [ f"{_inits[:-5]}_{i}.json" for i in chain_ids # type: ignore ] else: cmdstan_inits = _inits + if one_process_per_chain and isinstance(inv_metric, list): # legacy + cmdstan_metrics = [ + f"{_inv_metric[:-5]}_{i}.json" # type: ignore + for i in chain_ids + ] + else: + cmdstan_metrics = _inv_metric + + sampler_args = SamplerArgs( + num_chains=1 if one_process_per_chain else chains, + iter_warmup=iter_warmup, + iter_sampling=iter_sampling, + save_warmup=save_warmup, + thin=thin, + max_treedepth=max_treedepth, + metric_type=metric, # type: ignore + metric_file=cmdstan_metrics, + step_size=step_size, + adapt_engaged=adapt_engaged, + adapt_delta=adapt_delta, + adapt_init_phase=adapt_init_phase, + adapt_metric_window=adapt_metric_window, + adapt_step_size=adapt_step_size, + fixed_param=fixed_param, + ) args = CmdStanArgs( self._name, diff --git a/cmdstanpy/stanfit/mcmc.py b/cmdstanpy/stanfit/mcmc.py index a2d615e5..c71b28e1 100644 --- a/cmdstanpy/stanfit/mcmc.py +++ b/cmdstanpy/stanfit/mcmc.py @@ -223,19 +223,27 @@ def metric_type(self) -> Optional[str]: else None ) + # TODO(2.0): remove @property def metric(self) -> Optional[np.ndarray]: + """Deprecated. Use ``.inv_metric`` instead.""" + get_logger().warning( + 'The "metric" property is deprecated, use "inv_metric" instead. ' + 'This will be the same quantity, but with a more accurate name.' + ) + return self.inv_metric + + @property + def inv_metric(self) -> Optional[np.ndarray]: """ - Metric used by sampler for each chain. - When sampler algorithm 'fixed_param' is specified, metric is None. + Inverse mass matrix used by sampler for each chain. + Returns a ``nchains x nparams`` array when metric_type is 'diag_e', + a ``nchains x nparams x nparams`` array when metric_type is 'dense_e', + or ``None`` when metric_type is 'unit_e' or algorithm is 'fixed_param'. """ - if self._is_fixed_param: - return None - if self._metadata.cmdstan_config['metric'] == 'unit_e': - get_logger().info( - 'Unit diagnonal metric, inverse mass matrix size unknown.' - ) + if self._is_fixed_param or self.metric_type == 'unit_e': return None + self._assemble_draws() return self._metric diff --git a/cmdstanpy/utils/filesystem.py b/cmdstanpy/utils/filesystem.py index e8b361a7..1f293a05 100644 --- a/cmdstanpy/utils/filesystem.py +++ b/cmdstanpy/utils/filesystem.py @@ -10,6 +10,8 @@ import tempfile from typing import Any, Iterator, Mapping, Optional, Union +import numpy as np + from cmdstanpy import _TMPDIR from .json import write_stan_json @@ -165,6 +167,36 @@ def _temp_multiinput( yield from _temp_single_json(input) +@contextlib.contextmanager +def temp_metrics( + metrics: Union[ + str, os.PathLike, Mapping[str, Any], np.ndarray, list[Any], None + ], + *, + id: int = 1, +) -> Iterator[Union[str, None]]: + if isinstance(metrics, dict): + if 'inv_metric' not in metrics: + raise ValueError('Entry "inv_metric" not found in metric dict.') + if isinstance(metrics, np.ndarray): + metrics = {"inv_metric": metrics} + + if isinstance(metrics, list): + metrics_processed = [] + for init in metrics: + if isinstance(init, np.ndarray): + metrics_processed.append({"inv_metric": init}) + else: + metrics_processed.append(init) + if isinstance(metrics_processed, dict): + if 'inv_metric' not in metrics_processed: + raise ValueError( + 'Entry "inv_metric" not found in metric dict.' + ) + metrics = metrics_processed + yield from _temp_multiinput(metrics, base=id) + + @contextlib.contextmanager def temp_inits( inits: Union[ diff --git a/cmdstanpy/utils/stancsv.py b/cmdstanpy/utils/stancsv.py index 32f01d5f..aa8760d1 100644 --- a/cmdstanpy/utils/stancsv.py +++ b/cmdstanpy/utils/stancsv.py @@ -8,7 +8,7 @@ import os import re import warnings -from typing import Any, Iterator, Optional, Union +from typing import Any, Iterator, Mapping, Optional, Union import numpy as np import numpy.typing as npt @@ -631,3 +631,40 @@ def parse_rdump_value(rhs: str) -> Union[int, float, np.ndarray]: except TypeError as e: raise ValueError('bad value in Rdump file: {}'.format(rhs)) from e return val + + +def try_deduce_metric_type( + inv_metric: Union[ + str, + np.ndarray, + Mapping[str, Any], + list[Union[str, np.ndarray, Mapping[str, Any]]], + ], +) -> Optional[str]: + """Given a user-supplied metric, try to infer the correct metric type.""" + if isinstance(inv_metric, list): + if inv_metric: + inv_metric = inv_metric[0] + + if isinstance(inv_metric, Mapping): + if (metric_type := inv_metric.get("metric_type")) in ( + 'diag_e', + 'dense_e', + ): + return metric_type # type: ignore + inv_metric = inv_metric.get('inv_metric', None) + + if isinstance(inv_metric, np.ndarray): + if len(inv_metric.shape) == 1: + return 'diag_e' + else: + return 'dense_e' + + if isinstance(inv_metric, str): + dims = read_metric(inv_metric) + if len(dims) == 1: + return 'diag_e' + else: + return 'dense_e' + + return None diff --git a/cmdstanpy_tutorial.ipynb b/cmdstanpy_tutorial.ipynb index b88bb633..38401605 100644 --- a/cmdstanpy_tutorial.ipynb +++ b/cmdstanpy_tutorial.ipynb @@ -352,7 +352,7 @@ "metadata": {}, "outputs": [], "source": [ - "fit.metric_type, fit.metric" + "fit.metric_type, fit.inv_metric" ] }, { diff --git a/cmdstanpy_tutorial.py b/cmdstanpy_tutorial.py index e7f052f6..83a7b7a7 100644 --- a/cmdstanpy_tutorial.py +++ b/cmdstanpy_tutorial.py @@ -37,7 +37,7 @@ print(fit.step_size) print(fit.metric_type) -print(fit.metric) +print(fit.inv_metric) # #### Summarize the results diff --git a/docsrc/users-guide/examples/MCMC Sampling.ipynb b/docsrc/users-guide/examples/MCMC Sampling.ipynb index 3891fe46..053faa94 100644 --- a/docsrc/users-guide/examples/MCMC Sampling.ipynb +++ b/docsrc/users-guide/examples/MCMC Sampling.ipynb @@ -1483,7 +1483,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1502,7 +1502,7 @@ } ], "source": [ - "print(f'adapted step_size per chain\\n{fit.step_size}\\nmetric_type: {fit.metric_type}\\nmetric:\\n{fit.metric}')" + "print(f'adapted step_size per chain\\n{fit.step_size}\\nmetric_type: {fit.metric_type}\\ninverse metric:\\n{fit.inv_metric}')" ] }, { diff --git a/docsrc/users-guide/hello_world.rst b/docsrc/users-guide/hello_world.rst index 0fd88cc7..bc6d4f7b 100644 --- a/docsrc/users-guide/hello_world.rst +++ b/docsrc/users-guide/hello_world.rst @@ -171,7 +171,7 @@ access to the the per-chain HMC tuning parameters from the NUTS-HMC adaptive sam .. ipython:: python print(fit.metric_type) - print(fit.metric) + print(fit.inv_metric) print(fit.step_size) diff --git a/test/test_cmdstan_args.py b/test/test_cmdstan_args.py index 7587564d..14832374 100644 --- a/test/test_cmdstan_args.py +++ b/test/test_cmdstan_args.py @@ -151,7 +151,7 @@ def test_bad() -> None: with pytest.raises(ValueError): args.validate(chains=2) - args = SamplerArgs(metric='dense', fixed_param=True) + args = SamplerArgs(metric_type='dense', fixed_param=True) with pytest.raises(ValueError): args.validate(chains=2) @@ -221,22 +221,22 @@ def test_adapt() -> None: def test_metric() -> None: - args = SamplerArgs(metric='dense_e') + args = SamplerArgs(metric_type='dense_e') args.validate(chains=4) cmd = args.compose(1, cmd=[]) assert 'method=sample algorithm=hmc metric=dense_e' in ' '.join(cmd) - args = SamplerArgs(metric='dense') + args = SamplerArgs(metric_type='dense') args.validate(chains=4) cmd = args.compose(1, cmd=[]) assert 'method=sample algorithm=hmc metric=dense_e' in ' '.join(cmd) - args = SamplerArgs(metric='diag_e') + args = SamplerArgs(metric_type='diag_e') args.validate(chains=4) cmd = args.compose(1, cmd=[]) assert 'method=sample algorithm=hmc metric=diag_e' in ' '.join(cmd) - args = SamplerArgs(metric='diag') + args = SamplerArgs(metric_type='diag') args.validate(chains=4) cmd = args.compose(1, cmd=[]) assert 'method=sample algorithm=hmc metric=diag_e' in ' '.join(cmd) @@ -247,29 +247,20 @@ def test_metric() -> None: assert 'metric=' not in ' '.join(cmd) jmetric = os.path.join(DATAFILES_PATH, 'bernoulli.metric.json') - args = SamplerArgs(metric=jmetric) + args = SamplerArgs(metric_file=jmetric) args.validate(chains=4) cmd = args.compose(1, cmd=[]) - assert 'metric=diag_e' in ' '.join(cmd) assert 'metric_file=' in ' '.join(cmd) assert 'bernoulli.metric.json' in ' '.join(cmd) jmetric2 = os.path.join(DATAFILES_PATH, 'bernoulli.metric-2.json') - args = SamplerArgs(metric=[jmetric, jmetric2]) + args = SamplerArgs(metric_file=[jmetric, jmetric2]) args.validate(chains=2) cmd = args.compose(0, cmd=[]) assert 'bernoulli.metric.json' in ' '.join(cmd) cmd = args.compose(1, cmd=[]) assert 'bernoulli.metric-2.json' in ' '.join(cmd) - args = SamplerArgs(metric=[jmetric, jmetric2]) - with pytest.raises(ValueError): - args.validate(chains=4) - - args = SamplerArgs(metric='/no/such/path/to.file') - with pytest.raises(ValueError): - args.validate(chains=4) - def test_fixed_param() -> None: args = SamplerArgs(fixed_param=True) diff --git a/test/test_sample.py b/test/test_sample.py index 7f4f531f..026b9144 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -89,7 +89,7 @@ def test_bernoulli_good(stanfile: str): assert bern_fit.draws().shape == (100, 2, len(BERNOULLI_COLS)) assert bern_fit.metric_type == 'diag_e' assert bern_fit.step_size.shape == (2,) - assert bern_fit.metric.shape == (2, 1) + assert bern_fit.inv_metric.shape == (2, 1) assert bern_fit.draws(concat_chains=True).shape == ( 200, @@ -125,7 +125,7 @@ def test_bernoulli_good(stanfile: str): assert bern_sample.shape == (100, 2, len(BERNOULLI_COLS)) assert bern_fit.metric_type == 'dense_e' assert bern_fit.step_size.shape == (2,) - assert bern_fit.metric.shape == (2, 1, 1) + assert bern_fit.inv_metric.shape == (2, 1, 1) bern_fit = bern_model.sample( data=jdata, @@ -186,9 +186,7 @@ def test_bernoulli_good(stanfile: str): @pytest.mark.parametrize("stanfile", ["bernoulli.stan"]) -def test_bernoulli_unit_e( - stanfile: str, caplog: pytest.LogCaptureFixture -) -> None: +def test_bernoulli_unit_e(stanfile: str) -> None: stan = os.path.join(DATAFILES_PATH, stanfile) bern_model = CmdStanModel(stan_file=stan) @@ -204,19 +202,9 @@ def test_bernoulli_unit_e( show_progress=False, ) assert bern_fit.metric_type == 'unit_e' - assert bern_fit.metric is None + assert bern_fit.inv_metric is None assert bern_fit.step_size.shape == (2,) - with caplog.at_level(logging.INFO): - logging.getLogger() - assert bern_fit.metric is None - check_present( - caplog, - ( - 'cmdstanpy', - 'INFO', - 'Unit diagnonal metric, inverse mass matrix size unknown.', - ), - ) + assert bern_fit.draws().shape == (100, 2, len(BERNOULLI_COLS)) @@ -535,7 +523,7 @@ def test_fixed_param_good() -> None: ) assert datagen_fit.runset._args.method == Method.SAMPLE assert datagen_fit.metric_type is None - assert datagen_fit.metric is None + assert datagen_fit.inv_metric is None assert datagen_fit.step_size is None assert datagen_fit.divergences is None assert datagen_fit.max_treedepths is None @@ -638,7 +626,7 @@ def test_fixed_param_good() -> None: assert datagen_fit.column_names == tuple(column_names) assert datagen_fit.num_draws_sampling == 100 assert datagen_fit.draws().shape == (100, 1, len(column_names)) - assert datagen_fit.metric is None + assert datagen_fit.inv_metric is None assert datagen_fit.metric_type is None assert datagen_fit.step_size is None @@ -860,7 +848,7 @@ def test_validate_big_run() -> None: assert fit.column_names == tuple(column_names) assert fit.metric_type == 'diag_e' assert fit.step_size.shape == (2,) - assert fit.metric.shape == (2, 2095) + assert fit.inv_metric.shape == (2, 2095) assert fit.draws().shape == (1000, 2, 2102) assert fit.draws_pd(vars=['phi']).shape == (2000, 2095) with raises_nested(ValueError, r'Unknown variable: gamma'): @@ -1003,53 +991,99 @@ def test_from_csv_no_param_hmc() -> None: assert no_parameters_sample.draws_pd().shape == (100, 93) -def test_custom_metric() -> None: +@pytest.mark.parametrize('force_one_process_per_chain', [True, False]) +def test_custom_metric(force_one_process_per_chain: bool) -> None: stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan') jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') bern_model = CmdStanModel(stan_file=stan) jmetric = os.path.join(DATAFILES_PATH, 'bernoulli.metric.json') + jmetric2 = os.path.join(DATAFILES_PATH, 'bernoulli.metric-2.json') + # read json in as dict + with open(jmetric) as fd: + metric_dict_1 = json.load(fd) + with open(jmetric2) as fd: + metric_dict_2 = json.load(fd) # just test that it runs without error - bern_model.sample( + fit1 = bern_model.sample( data=jdata, chains=2, parallel_chains=2, seed=12345, - iter_warmup=100, - iter_sampling=200, - metric=jmetric, + iter_warmup=10, + iter_sampling=10, + inv_metric=jmetric, + force_one_process_per_chain=force_one_process_per_chain, ) - jmetric2 = os.path.join(DATAFILES_PATH, 'bernoulli.metric-2.json') - bern_model.sample( + np.testing.assert_allclose( + fit1.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6 + ) + np.testing.assert_allclose( + fit1.inv_metric[1], metric_dict_1['inv_metric'], atol=1e-6 + ) + + fit2 = bern_model.sample( data=jdata, chains=2, parallel_chains=2, seed=12345, - iter_warmup=100, - iter_sampling=200, - metric=[jmetric, jmetric2], + iter_warmup=10, + iter_sampling=10, + inv_metric=[jmetric, jmetric2], + force_one_process_per_chain=force_one_process_per_chain, ) - # read json in as dict - with open(jmetric) as fd: - metric_dict_1 = json.load(fd) - with open(jmetric2) as fd: - metric_dict_2 = json.load(fd) - bern_model.sample( + np.testing.assert_allclose( + fit2.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6 + ) + np.testing.assert_allclose( + fit2.inv_metric[1], metric_dict_2['inv_metric'], atol=1e-6 + ) + + fit3 = bern_model.sample( data=jdata, chains=4, parallel_chains=2, seed=12345, - iter_warmup=100, - iter_sampling=200, - metric=metric_dict_1, + iter_warmup=10, + iter_sampling=10, + inv_metric=metric_dict_1, + force_one_process_per_chain=force_one_process_per_chain, ) - bern_model.sample( + for i in range(4): + np.testing.assert_allclose( + fit3.inv_metric[i], metric_dict_1['inv_metric'], atol=1e-6 + ) + fit4 = bern_model.sample( data=jdata, chains=2, seed=12345, - iter_warmup=100, - iter_sampling=200, - metric=[metric_dict_1, metric_dict_2], + iter_warmup=10, + iter_sampling=10, + inv_metric=[metric_dict_1, metric_dict_2], + force_one_process_per_chain=force_one_process_per_chain, + ) + np.testing.assert_allclose( + fit4.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6 + ) + np.testing.assert_allclose( + fit4.inv_metric[1], metric_dict_2['inv_metric'], atol=1e-6 ) + + fit5 = bern_model.sample( + data=jdata, + chains=2, + seed=12345, + iter_warmup=10, + iter_sampling=10, + inv_metric=[np.array(metric_dict_1['inv_metric']), jmetric2], + force_one_process_per_chain=force_one_process_per_chain, + ) + np.testing.assert_allclose( + fit5.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6 + ) + np.testing.assert_allclose( + fit5.inv_metric[1], metric_dict_2['inv_metric'], atol=1e-6 + ) + with pytest.raises( ValueError, match='Number of metric files must match number of chains,', @@ -1059,25 +1093,25 @@ def test_custom_metric() -> None: chains=4, parallel_chains=2, seed=12345, - iter_warmup=100, - iter_sampling=200, - metric=[metric_dict_1, metric_dict_2], + iter_warmup=10, + iter_sampling=10, + inv_metric=[metric_dict_1, metric_dict_2], + force_one_process_per_chain=force_one_process_per_chain, ) # metric mismatches - (not appropriate for bernoulli) with open(os.path.join(DATAFILES_PATH, 'metric_diag.data.json')) as fd: metric_dict_1 = json.load(fd) with open(os.path.join(DATAFILES_PATH, 'metric_dense.data.json')) as fd: metric_dict_2 = json.load(fd) - with pytest.raises( - ValueError, match='Found inconsistent "inv_metric" entry' - ): + with pytest.raises(RuntimeError, match='Error during sampling'): bern_model.sample( data=jdata, chains=2, seed=12345, - iter_warmup=100, - iter_sampling=200, - metric=[metric_dict_1, metric_dict_2], + iter_warmup=10, + iter_sampling=10, + inv_metric=[metric_dict_1, metric_dict_2], + force_one_process_per_chain=force_one_process_per_chain, ) # metric dict, no "inv_metric": some_dict = {"foo": [1, 2, 3]} @@ -1090,7 +1124,8 @@ def test_custom_metric() -> None: seed=12345, iter_warmup=100, iter_sampling=200, - metric=some_dict, + inv_metric=some_dict, + force_one_process_per_chain=force_one_process_per_chain, ) @@ -2136,8 +2171,8 @@ def test_sample_dense_mass_matrix(): linear_model = CmdStanModel(stan_file=stan) fit = linear_model.sample(data=jdata, metric="dense_e", chains=2) - assert fit.metric is not None - assert fit.metric.shape == (2, 3, 3) + assert fit.inv_metric is not None + assert fit.inv_metric.shape == (2, 3, 3) def test_no_output_draws(): diff --git a/test/test_utils.py b/test/test_utils.py index 5a4d6167..5eebbb62 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -453,6 +453,21 @@ def test_metric_missing() -> None: read_metric(metric_file) +def test_deduce_metric_type() -> None: + assert stancsv.try_deduce_metric_type(np.zeros((3, 3))) == 'dense_e' + assert stancsv.try_deduce_metric_type(np.zeros((3,))) == 'diag_e' + + assert stancsv.try_deduce_metric_type([np.zeros((3, 3))]) == 'dense_e' + assert ( + stancsv.try_deduce_metric_type({"inv_metric": np.zeros((3,))}) + == 'diag_e' + ) + assert ( + stancsv.try_deduce_metric_type([{"inv_metric": np.zeros((3,))}]) + == 'diag_e' + ) + + @mark_windows_only def test_windows_short_path_directory() -> None: with tempfile.TemporaryDirectory(