Skip to content

Commit 4db7ecb

Browse files
committed
Rework how initial inverse mass matrix can be supplied, deprecate former overloading of metric argument
1 parent ddb203a commit 4db7ecb

File tree

7 files changed

+211
-201
lines changed

7 files changed

+211
-201
lines changed

cmdstanpy/cmdstan_args.py

Lines changed: 18 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,16 @@
11
"""
22
CmdStan arguments
33
"""
4+
45
import os
56
from enum import Enum, auto
67
from time import time
7-
from typing import Any, Dict, List, Mapping, Optional, Union
8+
from typing import Any, List, Mapping, Optional, Union
89

910
import numpy as np
1011
from numpy.random import default_rng
1112

12-
from cmdstanpy import _TMPDIR
13-
from cmdstanpy.utils import (
14-
cmdstan_path,
15-
cmdstan_version_before,
16-
create_named_text_file,
17-
get_logger,
18-
read_metric,
19-
write_stan_json,
20-
)
13+
from cmdstanpy.utils import cmdstan_path, cmdstan_version_before, get_logger
2114

2215
OptionalPath = Union[str, os.PathLike, None]
2316

@@ -64,9 +57,8 @@ def __init__(
6457
save_warmup: bool = False,
6558
thin: Optional[int] = None,
6659
max_treedepth: Optional[int] = None,
67-
metric: Union[
68-
str, Dict[str, Any], List[str], List[Dict[str, Any]], None
69-
] = None,
60+
metric_type: Optional[str] = None,
61+
metric_file: Union[str, List[str], None] = None,
7062
step_size: Union[float, List[float], None] = None,
7163
adapt_engaged: bool = True,
7264
adapt_delta: Optional[float] = None,
@@ -82,9 +74,8 @@ def __init__(
8274
self.save_warmup = save_warmup
8375
self.thin = thin
8476
self.max_treedepth = max_treedepth
85-
self.metric = metric
86-
self.metric_type: Optional[str] = None
87-
self.metric_file: Union[str, List[str], None] = None
77+
self.metric_type: Optional[str] = metric_type
78+
self.metric_file: Union[str, List[str], None] = metric_file
8879
self.step_size = step_size
8980
self.adapt_engaged = adapt_engaged
9081
self.adapt_delta = adapt_delta
@@ -176,124 +167,15 @@ def validate(self, chains: Optional[int]) -> None:
176167
'Argument "step_size" must be > 0, '
177168
'chain {}, found {}.'.format(i + 1, step_size)
178169
)
179-
if self.metric is not None:
180-
if isinstance(self.metric, str):
181-
if self.metric in ['diag', 'diag_e']:
182-
self.metric_type = 'diag_e'
183-
elif self.metric in ['dense', 'dense_e']:
184-
self.metric_type = 'dense_e'
185-
elif self.metric in ['unit', 'unit_e']:
186-
self.metric_type = 'unit_e'
187-
else:
188-
if not os.path.exists(self.metric):
189-
raise ValueError('no such file {}'.format(self.metric))
190-
dims = read_metric(self.metric)
191-
if len(dims) == 1:
192-
self.metric_type = 'diag_e'
193-
else:
194-
self.metric_type = 'dense_e'
195-
self.metric_file = self.metric
196-
elif isinstance(self.metric, dict):
197-
if 'inv_metric' not in self.metric:
198-
raise ValueError(
199-
'Entry "inv_metric" not found in metric dict.'
200-
)
201-
dims = list(np.asarray(self.metric['inv_metric']).shape)
202-
if len(dims) == 1:
203-
self.metric_type = 'diag_e'
204-
else:
205-
self.metric_type = 'dense_e'
206-
dict_file = create_named_text_file(
207-
dir=_TMPDIR, prefix="metric", suffix=".json"
208-
)
209-
write_stan_json(dict_file, self.metric)
210-
self.metric_file = dict_file
211-
elif isinstance(self.metric, (list, tuple)):
212-
if len(self.metric) != chains:
213-
raise ValueError(
214-
'Number of metric files must match number of chains,'
215-
' found {} metric files for {} chains.'.format(
216-
len(self.metric), chains
217-
)
218-
)
219-
if all(isinstance(elem, dict) for elem in self.metric):
220-
metric_files: List[str] = []
221-
for i, metric in enumerate(self.metric):
222-
metric_dict: Dict[str, Any] = metric # type: ignore
223-
if 'inv_metric' not in metric_dict:
224-
raise ValueError(
225-
'Entry "inv_metric" not found in metric dict '
226-
'for chain {}.'.format(i + 1)
227-
)
228-
if i == 0:
229-
dims = list(
230-
np.asarray(metric_dict['inv_metric']).shape
231-
)
232-
else:
233-
dims2 = list(
234-
np.asarray(metric_dict['inv_metric']).shape
235-
)
236-
if dims != dims2:
237-
raise ValueError(
238-
'Found inconsistent "inv_metric" entry '
239-
'for chain {}: entry has dims '
240-
'{}, expected {}.'.format(
241-
i + 1, dims, dims2
242-
)
243-
)
244-
dict_file = create_named_text_file(
245-
dir=_TMPDIR, prefix="metric", suffix=".json"
246-
)
247-
write_stan_json(dict_file, metric_dict)
248-
metric_files.append(dict_file)
249-
if len(dims) == 1:
250-
self.metric_type = 'diag_e'
251-
else:
252-
self.metric_type = 'dense_e'
253-
self.metric_file = metric_files
254-
elif all(isinstance(elem, str) for elem in self.metric):
255-
metric_files = []
256-
for i, metric in enumerate(self.metric):
257-
assert isinstance(metric, str) # typecheck
258-
if not os.path.exists(metric):
259-
raise ValueError('no such file {}'.format(metric))
260-
if i == 0:
261-
dims = read_metric(metric)
262-
else:
263-
dims2 = read_metric(metric)
264-
if len(dims) != len(dims2):
265-
raise ValueError(
266-
'Metrics files {}, {},'
267-
' inconsistent metrics'.format(
268-
self.metric[0], metric
269-
)
270-
)
271-
if dims != dims2:
272-
raise ValueError(
273-
'Metrics files {}, {},'
274-
' inconsistent metrics'.format(
275-
self.metric[0], metric
276-
)
277-
)
278-
metric_files.append(metric)
279-
if len(dims) == 1:
280-
self.metric_type = 'diag_e'
281-
else:
282-
self.metric_type = 'dense_e'
283-
self.metric_file = metric_files
284-
else:
285-
raise ValueError(
286-
'Argument "metric" must be a list of pathnames or '
287-
'Python dicts, found list of {}.'.format(
288-
type(self.metric[0])
289-
)
290-
)
291-
else:
170+
if self.metric_type is not None:
171+
if self.metric_type in ['diag', 'dense', 'unit']:
172+
self.metric_type += '_e'
173+
if self.metric_type not in ['diag_e', 'dense_e', 'unit_e']:
292174
raise ValueError(
293-
'Invalid metric specified, not a recognized metric type, '
294-
'must be either a metric type name, a filepath, dict, '
295-
'or list of per-chain filepaths or dicts. Found '
296-
'an object of type {}.'.format(type(self.metric))
175+
'Argument "metric" must be one of [diag, dense, unit,'
176+
' diag_e, dense_e, unit_e], found {}.'.format(
177+
self.metric_type
178+
)
297179
)
298180

299181
if self.adapt_delta is not None:
@@ -330,7 +212,8 @@ def validate(self, chains: Optional[int]) -> None:
330212

331213
if self.fixed_param and (
332214
self.max_treedepth is not None
333-
or self.metric is not None
215+
or self.metric_type is not None
216+
or self.metric_file is not None
334217
or self.step_size is not None
335218
or not (
336219
self.adapt_delta is None
@@ -369,7 +252,7 @@ def compose(self, idx: int, cmd: List[str]) -> List[str]:
369252
cmd.append(f'stepsize={self.step_size}')
370253
else:
371254
cmd.append(f'stepsize={self.step_size[idx]}')
372-
if self.metric is not None:
255+
if self.metric_type is not None:
373256
cmd.append(f'metric={self.metric_type}')
374257
if self.metric_file is not None:
375258
if not isinstance(self.metric_file, list):

cmdstanpy/model.py

Lines changed: 83 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Union,
2727
)
2828

29+
import numpy as np
2930
import pandas as pd
3031
from tqdm.auto import tqdm
3132

@@ -57,7 +58,12 @@
5758
get_logger,
5859
returncode_msg,
5960
)
60-
from cmdstanpy.utils.filesystem import temp_inits, temp_single_json
61+
from cmdstanpy.utils.filesystem import (
62+
temp_inits,
63+
temp_metrics,
64+
temp_single_json,
65+
)
66+
from cmdstanpy.utils.stancsv import try_deduce_metric_type
6167

6268
from . import progress as progbar
6369

@@ -698,6 +704,13 @@ def sample(
698704
timeout: Optional[float] = None,
699705
*,
700706
force_one_process_per_chain: Optional[bool] = None,
707+
inv_metric: Union[
708+
str,
709+
np.ndarray,
710+
Mapping[str, Any],
711+
List[Union[str, np.ndarray, Mapping[str, Any]]],
712+
None,
713+
] = None,
701714
) -> CmdStanMCMC:
702715
"""
703716
Run one or more chains of the NUTS-HMC sampler to produce a set of draws
@@ -786,29 +799,23 @@ def sample(
786799
:param max_treedepth: Maximum depth of trees evaluated by NUTS sampler
787800
per iteration.
788801
789-
:param metric: Specification of the mass matrix, either as a
790-
vector consisting of the diagonal elements of the covariance
791-
matrix ('diag' or 'diag_e') or the full covariance matrix
792-
('dense' or 'dense_e').
793-
794-
If the value of the metric argument is a string other than
795-
'diag', 'diag_e', 'dense', or 'dense_e', it must be
796-
a valid filepath to a JSON or Rdump file which contains an entry
797-
'inv_metric' whose value is either the diagonal vector or
798-
the full covariance matrix.
802+
:param metric: Specify the type of the mass matrix. Options are
803+
'diag' or 'diag_e' for diagonal matrix, 'dense' or 'dense_e'
804+
for a dense matrix, or 'unit_e' for an identity mass matrix.
799805
800-
If the value of the metric argument is a list of paths, its
801-
length must match the number of chains and all paths must be
802-
unique.
806+
:param inv_metric: Provide an initial value for the inverse
807+
mass matrix.
803808
804-
If the value of the metric argument is a Python dict object, it
805-
must contain an entry 'inv_metric' which specifies either the
806-
diagnoal or dense matrix.
807-
808-
If the value of the metric argument is a list of Python dicts,
809-
its length must match the number of chains and all dicts must
810-
containan entry 'inv_metric' and all 'inv_metric' entries must
811-
have the same shape.
809+
Valid options include:
810+
- a string, which must be a valid filepath to a JSON or
811+
Rdump file which contains an entry 'inv_metric' whose value
812+
is either a diagonal vector or dense matrix.
813+
- a numpy array containing either the diagonal vector or dense
814+
matrix.
815+
- a dictionary containing an entry 'inv_metric' whose value
816+
is either a diagonal vector or dense matrix.
817+
- a list of any of the above, of length num_chains, with
818+
the same shape of metric in each entry.
812819
813820
:param step_size: Initial step size for HMC sampler. The value is
814821
either a single number or a list of numbers which will be used
@@ -1002,34 +1009,71 @@ def sample(
10021009
'Chain_id must be a non-negative integer value,'
10031010
' found {}.'.format(chain_id)
10041011
)
1012+
if metric not in [None, 'diag', 'dense', 'unit_e', 'diag_e', 'dense_e']:
1013+
get_logger().warning(
1014+
"Providing anything other than metric type for"
1015+
" 'metric' is deprecated and will be removed"
1016+
" in the next major release."
1017+
" Please provide such information via"
1018+
" 'inv_metric' argument."
1019+
)
1020+
if inv_metric is not None:
1021+
raise ValueError(
1022+
"Cannot provide both (deprecated) non-metric-type 'metric'"
1023+
" argument and 'inv_metric' argument."
1024+
)
1025+
inv_metric = metric # type: ignore # for backwards compatibility
1026+
metric = None
10051027

1006-
sampler_args = SamplerArgs(
1007-
num_chains=1 if one_process_per_chain else chains,
1008-
iter_warmup=iter_warmup,
1009-
iter_sampling=iter_sampling,
1010-
save_warmup=save_warmup,
1011-
thin=thin,
1012-
max_treedepth=max_treedepth,
1013-
metric=metric,
1014-
step_size=step_size,
1015-
adapt_engaged=adapt_engaged,
1016-
adapt_delta=adapt_delta,
1017-
adapt_init_phase=adapt_init_phase,
1018-
adapt_metric_window=adapt_metric_window,
1019-
adapt_step_size=adapt_step_size,
1020-
fixed_param=fixed_param,
1021-
)
1028+
if metric is None and inv_metric is not None:
1029+
metric = try_deduce_metric_type(inv_metric)
1030+
1031+
if isinstance(inv_metric, list):
1032+
if not len(inv_metric) == chains:
1033+
raise ValueError(
1034+
'Number of metric files must match number of chains,'
1035+
' found {} metric files for {} chains.'.format(
1036+
len(inv_metric), chains
1037+
)
1038+
)
10221039

10231040
with temp_single_json(data) as _data, temp_inits(
10241041
inits, id=chain_ids[0]
1025-
) as _inits:
1042+
) as _inits, temp_metrics(inv_metric, id=chain_ids[0]) as _inv_metric:
10261043
cmdstan_inits: Union[str, List[str], int, float, None]
1044+
cmdstan_metrics: Union[str, List[str], None]
1045+
10271046
if one_process_per_chain and isinstance(inits, list): # legacy
10281047
cmdstan_inits = [
10291048
f"{_inits[:-5]}_{i}.json" for i in chain_ids # type: ignore
10301049
]
10311050
else:
10321051
cmdstan_inits = _inits
1052+
if one_process_per_chain and isinstance(inv_metric, list): # legacy
1053+
cmdstan_metrics = [
1054+
f"{_inv_metric[:-5]}_{i}.json" # type: ignore
1055+
for i in chain_ids
1056+
]
1057+
else:
1058+
cmdstan_metrics = _inv_metric
1059+
1060+
sampler_args = SamplerArgs(
1061+
num_chains=1 if one_process_per_chain else chains,
1062+
iter_warmup=iter_warmup,
1063+
iter_sampling=iter_sampling,
1064+
save_warmup=save_warmup,
1065+
thin=thin,
1066+
max_treedepth=max_treedepth,
1067+
metric_type=metric, # type: ignore
1068+
metric_file=cmdstan_metrics,
1069+
step_size=step_size,
1070+
adapt_engaged=adapt_engaged,
1071+
adapt_delta=adapt_delta,
1072+
adapt_init_phase=adapt_init_phase,
1073+
adapt_metric_window=adapt_metric_window,
1074+
adapt_step_size=adapt_step_size,
1075+
fixed_param=fixed_param,
1076+
)
10331077

10341078
args = CmdStanArgs(
10351079
self._name,

0 commit comments

Comments
 (0)