Skip to content

Commit 0b08591

Browse files
authored
Merge pull request #804 from stan-dev/deprecations/rename-to-inv_metric
Deprecate metric argument and property, rename to inv_metric
2 parents 54a0416 + cd5b5de commit 0b08591

File tree

12 files changed

+312
-259
lines changed

12 files changed

+312
-259
lines changed

cmdstanpy/cmdstan_args.py

Lines changed: 16 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,7 @@
1010
import numpy as np
1111
from numpy.random import default_rng
1212

13-
from cmdstanpy import _TMPDIR
14-
from cmdstanpy.utils import (
15-
cmdstan_path,
16-
cmdstan_version_before,
17-
create_named_text_file,
18-
get_logger,
19-
read_metric,
20-
write_stan_json,
21-
)
13+
from cmdstanpy.utils import cmdstan_path, cmdstan_version_before, get_logger
2214

2315
OptionalPath = Union[str, os.PathLike, None]
2416

@@ -65,9 +57,8 @@ def __init__(
6557
save_warmup: bool = False,
6658
thin: Optional[int] = None,
6759
max_treedepth: Optional[int] = None,
68-
metric: Union[
69-
str, dict[str, Any], list[str], list[dict[str, Any]], None
70-
] = None,
60+
metric_type: Optional[str] = None,
61+
metric_file: Union[str, list[str], None] = None,
7162
step_size: Union[float, list[float], None] = None,
7263
adapt_engaged: bool = True,
7364
adapt_delta: Optional[float] = None,
@@ -83,9 +74,8 @@ def __init__(
8374
self.save_warmup = save_warmup
8475
self.thin = thin
8576
self.max_treedepth = max_treedepth
86-
self.metric = metric
87-
self.metric_type: Optional[str] = None
88-
self.metric_file: Union[str, list[str], None] = None
77+
self.metric_type: Optional[str] = metric_type
78+
self.metric_file: Union[str, list[str], None] = metric_file
8979
self.step_size = step_size
9080
self.adapt_engaged = adapt_engaged
9181
self.adapt_delta = adapt_delta
@@ -178,124 +168,15 @@ def validate(self, chains: Optional[int]) -> None:
178168
'Argument "step_size" must be > 0, '
179169
'chain {}, found {}.'.format(i + 1, step_size)
180170
)
181-
if self.metric is not None:
182-
if isinstance(self.metric, str):
183-
if self.metric in ['diag', 'diag_e']:
184-
self.metric_type = 'diag_e'
185-
elif self.metric in ['dense', 'dense_e']:
186-
self.metric_type = 'dense_e'
187-
elif self.metric in ['unit', 'unit_e']:
188-
self.metric_type = 'unit_e'
189-
else:
190-
if not os.path.exists(self.metric):
191-
raise ValueError('no such file {}'.format(self.metric))
192-
dims = read_metric(self.metric)
193-
if len(dims) == 1:
194-
self.metric_type = 'diag_e'
195-
else:
196-
self.metric_type = 'dense_e'
197-
self.metric_file = self.metric
198-
elif isinstance(self.metric, dict):
199-
if 'inv_metric' not in self.metric:
200-
raise ValueError(
201-
'Entry "inv_metric" not found in metric dict.'
202-
)
203-
dims = list(np.asarray(self.metric['inv_metric']).shape)
204-
if len(dims) == 1:
205-
self.metric_type = 'diag_e'
206-
else:
207-
self.metric_type = 'dense_e'
208-
dict_file = create_named_text_file(
209-
dir=_TMPDIR, prefix="metric", suffix=".json"
210-
)
211-
write_stan_json(dict_file, self.metric)
212-
self.metric_file = dict_file
213-
elif isinstance(self.metric, (list, tuple)):
214-
if len(self.metric) != chains:
215-
raise ValueError(
216-
'Number of metric files must match number of chains,'
217-
' found {} metric files for {} chains.'.format(
218-
len(self.metric), chains
219-
)
220-
)
221-
if all(isinstance(elem, dict) for elem in self.metric):
222-
metric_files: list[str] = []
223-
for i, metric in enumerate(self.metric):
224-
metric_dict: dict[str, Any] = metric # type: ignore
225-
if 'inv_metric' not in metric_dict:
226-
raise ValueError(
227-
'Entry "inv_metric" not found in metric dict '
228-
'for chain {}.'.format(i + 1)
229-
)
230-
if i == 0:
231-
dims = list(
232-
np.asarray(metric_dict['inv_metric']).shape
233-
)
234-
else:
235-
dims2 = list(
236-
np.asarray(metric_dict['inv_metric']).shape
237-
)
238-
if dims != dims2:
239-
raise ValueError(
240-
'Found inconsistent "inv_metric" entry '
241-
'for chain {}: entry has dims '
242-
'{}, expected {}.'.format(
243-
i + 1, dims, dims2
244-
)
245-
)
246-
dict_file = create_named_text_file(
247-
dir=_TMPDIR, prefix="metric", suffix=".json"
248-
)
249-
write_stan_json(dict_file, metric_dict)
250-
metric_files.append(dict_file)
251-
if len(dims) == 1:
252-
self.metric_type = 'diag_e'
253-
else:
254-
self.metric_type = 'dense_e'
255-
self.metric_file = metric_files
256-
elif all(isinstance(elem, str) for elem in self.metric):
257-
metric_files = []
258-
for i, metric in enumerate(self.metric):
259-
assert isinstance(metric, str) # typecheck
260-
if not os.path.exists(metric):
261-
raise ValueError('no such file {}'.format(metric))
262-
if i == 0:
263-
dims = read_metric(metric)
264-
else:
265-
dims2 = read_metric(metric)
266-
if len(dims) != len(dims2):
267-
raise ValueError(
268-
'Metrics files {}, {},'
269-
' inconsistent metrics'.format(
270-
self.metric[0], metric
271-
)
272-
)
273-
if dims != dims2:
274-
raise ValueError(
275-
'Metrics files {}, {},'
276-
' inconsistent metrics'.format(
277-
self.metric[0], metric
278-
)
279-
)
280-
metric_files.append(metric)
281-
if len(dims) == 1:
282-
self.metric_type = 'diag_e'
283-
else:
284-
self.metric_type = 'dense_e'
285-
self.metric_file = metric_files
286-
else:
287-
raise ValueError(
288-
'Argument "metric" must be a list of pathnames or '
289-
'Python dicts, found list of {}.'.format(
290-
type(self.metric[0])
291-
)
292-
)
293-
else:
171+
if self.metric_type is not None:
172+
if self.metric_type in ['diag', 'dense', 'unit']:
173+
self.metric_type += '_e'
174+
if self.metric_type not in ['diag_e', 'dense_e', 'unit_e']:
294175
raise ValueError(
295-
'Invalid metric specified, not a recognized metric type, '
296-
'must be either a metric type name, a filepath, dict, '
297-
'or list of per-chain filepaths or dicts. Found '
298-
'an object of type {}.'.format(type(self.metric))
176+
'Argument "metric" must be one of [diag, dense, unit,'
177+
' diag_e, dense_e, unit_e], found {}.'.format(
178+
self.metric_type
179+
)
299180
)
300181

301182
if self.adapt_delta is not None:
@@ -332,7 +213,8 @@ def validate(self, chains: Optional[int]) -> None:
332213

333214
if self.fixed_param and (
334215
self.max_treedepth is not None
335-
or self.metric is not None
216+
or self.metric_type is not None
217+
or self.metric_file is not None
336218
or self.step_size is not None
337219
or not (
338220
self.adapt_delta is None
@@ -371,7 +253,7 @@ def compose(self, idx: int, cmd: list[str]) -> list[str]:
371253
cmd.append(f'stepsize={self.step_size}')
372254
else:
373255
cmd.append(f'stepsize={self.step_size[idx]}')
374-
if self.metric is not None:
256+
if self.metric_type is not None:
375257
cmd.append(f'metric={self.metric_type}')
376258
if self.metric_file is not None:
377259
if not isinstance(self.metric_file, list):

cmdstanpy/model.py

Lines changed: 93 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Union,
2525
)
2626

27+
import numpy as np
2728
import pandas as pd
2829
from tqdm.auto import tqdm
2930

@@ -55,7 +56,12 @@
5556
get_logger,
5657
returncode_msg,
5758
)
58-
from cmdstanpy.utils.filesystem import temp_inits, temp_single_json
59+
from cmdstanpy.utils.filesystem import (
60+
temp_inits,
61+
temp_metrics,
62+
temp_single_json,
63+
)
64+
from cmdstanpy.utils.stancsv import try_deduce_metric_type
5965

6066
from . import progress as progbar
6167

@@ -697,6 +703,13 @@ def sample(
697703
timeout: Optional[float] = None,
698704
*,
699705
force_one_process_per_chain: Optional[bool] = None,
706+
inv_metric: Union[
707+
str,
708+
np.ndarray,
709+
Mapping[str, Any],
710+
list[Union[str, np.ndarray, Mapping[str, Any]]],
711+
None,
712+
] = None,
700713
) -> CmdStanMCMC:
701714
"""
702715
Run or more chains of the NUTS-HMC sampler to produce a set of draws
@@ -785,29 +798,25 @@ def sample(
785798
:param max_treedepth: Maximum depth of trees evaluated by NUTS sampler
786799
per iteration.
787800
788-
:param metric: Specification of the mass matrix, either as a
789-
vector consisting of the diagonal elements of the covariance
790-
matrix ('diag' or 'diag_e') or the full covariance matrix
791-
('dense' or 'dense_e').
792-
793-
If the value of the metric argument is a string other than
794-
'diag', 'diag_e', 'dense', or 'dense_e', it must be
795-
a valid filepath to a JSON or Rdump file which contains an entry
796-
'inv_metric' whose value is either the diagonal vector or
797-
the full covariance matrix.
798-
799-
If the value of the metric argument is a list of paths, its
800-
length must match the number of chains and all paths must be
801-
unique.
802-
803-
If the value of the metric argument is a Python dict object, it
804-
must contain an entry 'inv_metric' which specifies either the
805-
diagnoal or dense matrix.
806-
807-
If the value of the metric argument is a list of Python dicts,
808-
its length must match the number of chains and all dicts must
809-
containan entry 'inv_metric' and all 'inv_metric' entries must
810-
have the same shape.
801+
:param metric: Specify the type of the inverse mass matrix. Options are
802+
'diag' or 'diag_e' for diagonal matrix, 'dense' or 'dense_e'
803+
for a dense matrix, or 'unit_e' an identity mass matrix. To provide
804+
an initial value for the inverse mass matrix, use the ``inv_metric``
805+
argument.
806+
807+
:param inv_metric: Provide an initial value for the inverse
808+
mass matrix.
809+
810+
Valid options include:
811+
- a string, which must be a valid filepath to a JSON or
812+
Rdump file which contains an entry 'inv_metric' whose value
813+
is either a diagonal vector or dense matrix.
814+
- a numpy array containing either the diagonal vector or dense
815+
matrix.
816+
- a dictionary containing an entry 'inv_metric' whose value
817+
is either a diagonal vector or dense matrix.
818+
- a list of any of the above, of length num_chains, with
819+
the same shape of metric in each entry.
811820
812821
:param step_size: Initial step size for HMC sampler. The value is
813822
either a single number or a list of numbers which will be used
@@ -1001,35 +1010,79 @@ def sample(
10011010
'Chain_id must be a non-negative integer value,'
10021011
' found {}.'.format(chain_id)
10031012
)
1013+
if metric is not None and metric not in (
1014+
'diag',
1015+
'dense',
1016+
'unit_e',
1017+
'diag_e',
1018+
'dense_e',
1019+
):
1020+
get_logger().warning(
1021+
"Providing anything other than metric type for"
1022+
" 'metric' is deprecated and will be removed"
1023+
" in the next major release."
1024+
" Please provide such information via"
1025+
" 'inv_metric' argument."
1026+
)
1027+
if inv_metric is not None:
1028+
raise ValueError(
1029+
"Cannot provide both (deprecated) non-metric-type 'metric'"
1030+
" argument and 'inv_metric' argument."
1031+
)
1032+
inv_metric = metric # type: ignore # for backwards compatibility
1033+
metric = None
10041034

1005-
sampler_args = SamplerArgs(
1006-
num_chains=1 if one_process_per_chain else chains,
1007-
iter_warmup=iter_warmup,
1008-
iter_sampling=iter_sampling,
1009-
save_warmup=save_warmup,
1010-
thin=thin,
1011-
max_treedepth=max_treedepth,
1012-
metric=metric,
1013-
step_size=step_size,
1014-
adapt_engaged=adapt_engaged,
1015-
adapt_delta=adapt_delta,
1016-
adapt_init_phase=adapt_init_phase,
1017-
adapt_metric_window=adapt_metric_window,
1018-
adapt_step_size=adapt_step_size,
1019-
fixed_param=fixed_param,
1020-
)
1035+
if metric is None and inv_metric is not None:
1036+
metric = try_deduce_metric_type(inv_metric)
1037+
1038+
if isinstance(inv_metric, list):
1039+
if not len(inv_metric) == chains:
1040+
raise ValueError(
1041+
'Number of metric files must match number of chains,'
1042+
' found {} metric files for {} chains.'.format(
1043+
len(inv_metric), chains
1044+
)
1045+
)
10211046

10221047
with (
10231048
temp_single_json(data) as _data,
10241049
temp_inits(inits, id=chain_ids[0]) as _inits,
1050+
temp_metrics(inv_metric, id=chain_ids[0]) as _inv_metric,
10251051
):
10261052
cmdstan_inits: Union[str, list[str], int, float, None]
1053+
cmdstan_metrics: Union[str, list[str], None]
1054+
10271055
if one_process_per_chain and isinstance(inits, list): # legacy
10281056
cmdstan_inits = [
10291057
f"{_inits[:-5]}_{i}.json" for i in chain_ids # type: ignore
10301058
]
10311059
else:
10321060
cmdstan_inits = _inits
1061+
if one_process_per_chain and isinstance(inv_metric, list): # legacy
1062+
cmdstan_metrics = [
1063+
f"{_inv_metric[:-5]}_{i}.json" # type: ignore
1064+
for i in chain_ids
1065+
]
1066+
else:
1067+
cmdstan_metrics = _inv_metric
1068+
1069+
sampler_args = SamplerArgs(
1070+
num_chains=1 if one_process_per_chain else chains,
1071+
iter_warmup=iter_warmup,
1072+
iter_sampling=iter_sampling,
1073+
save_warmup=save_warmup,
1074+
thin=thin,
1075+
max_treedepth=max_treedepth,
1076+
metric_type=metric, # type: ignore
1077+
metric_file=cmdstan_metrics,
1078+
step_size=step_size,
1079+
adapt_engaged=adapt_engaged,
1080+
adapt_delta=adapt_delta,
1081+
adapt_init_phase=adapt_init_phase,
1082+
adapt_metric_window=adapt_metric_window,
1083+
adapt_step_size=adapt_step_size,
1084+
fixed_param=fixed_param,
1085+
)
10331086

10341087
args = CmdStanArgs(
10351088
self._name,

0 commit comments

Comments
 (0)