Skip to content

Commit ddb203a

Browse files
committed
Deprecate CmdStanMCMC.metric, provide .inv_metric instead
1 parent 71d97ca commit ddb203a

File tree

6 files changed

+31
-35
lines changed

6 files changed

+31
-35
lines changed

cmdstanpy/stanfit/mcmc.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,19 +233,27 @@ def metric_type(self) -> Optional[str]:
233233
else None
234234
)
235235

236+
# TODO(2.0): remove
236237
@property
237238
def metric(self) -> Optional[np.ndarray]:
239+
"""Deprecated. Use ``.inv_metric`` instead."""
240+
get_logger().warning(
241+
'The "metric" property is deprecated, use "inv_metric" instead. '
242+
'This will be the same quantity, but with a more accurate name.'
243+
)
244+
return self.inv_metric
245+
246+
@property
247+
def inv_metric(self) -> Optional[np.ndarray]:
238248
"""
239-
Metric used by sampler for each chain.
240-
When sampler algorithm 'fixed_param' is specified, metric is None.
249+
Inverse mass matrix used by sampler for each chain.
250+
Returns a ``nchains x nparams`` array when metric_type is 'diag_e',
251+
a ``nchains x nparams x nparams`` array when metric_type is 'dense_e',
252+
or ``None`` when metric_type is 'unit_e' or algorithm is 'fixed_param'.
241253
"""
242-
if self._is_fixed_param:
243-
return None
244-
if self._metadata.cmdstan_config['metric'] == 'unit_e':
245-
get_logger().info(
246-
'Unit diagnonal metric, inverse mass matrix size unknown.'
247-
)
254+
if self._is_fixed_param or self.metric_type == 'unit_e':
248255
return None
256+
249257
self._assemble_draws()
250258
return self._metric
251259

cmdstanpy_tutorial.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@
352352
"metadata": {},
353353
"outputs": [],
354354
"source": [
355-
"fit.metric_type, fit.metric"
355+
"fit.metric_type, fit.inv_metric"
356356
]
357357
},
358358
{

cmdstanpy_tutorial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838
print(fit.step_size)
3939
print(fit.metric_type)
40-
print(fit.metric)
40+
print(fit.inv_metric)
4141

4242
# #### Summarize the results
4343

docsrc/users-guide/examples/MCMC Sampling.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,7 +1483,7 @@
14831483
},
14841484
{
14851485
"cell_type": "code",
1486-
"execution_count": 25,
1486+
"execution_count": null,
14871487
"metadata": {},
14881488
"outputs": [
14891489
{
@@ -1502,7 +1502,7 @@
15021502
}
15031503
],
15041504
"source": [
1505-
"print(f'adapted step_size per chain\\n{fit.step_size}\\nmetric_type: {fit.metric_type}\\nmetric:\\n{fit.metric}')"
1505+
"print(f'adapted step_size per chain\\n{fit.step_size}\\nmetric_type: {fit.metric_type}\\ninverse metric:\\n{fit.inv_metric}')"
15061506
]
15071507
},
15081508
{

docsrc/users-guide/hello_world.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ access to the the per-chain HMC tuning parameters from the NUTS-HMC adaptive sam
171171
.. ipython:: python
172172
173173
print(fit.metric_type)
174-
print(fit.metric)
174+
print(fit.inv_metric)
175175
print(fit.step_size)
176176
177177

test/test_sample.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def test_bernoulli_good(stanfile: str):
8989
assert bern_fit.draws().shape == (100, 2, len(BERNOULLI_COLS))
9090
assert bern_fit.metric_type == 'diag_e'
9191
assert bern_fit.step_size.shape == (2,)
92-
assert bern_fit.metric.shape == (2, 1)
92+
assert bern_fit.inv_metric.shape == (2, 1)
9393

9494
assert bern_fit.draws(concat_chains=True).shape == (
9595
200,
@@ -125,7 +125,7 @@ def test_bernoulli_good(stanfile: str):
125125
assert bern_sample.shape == (100, 2, len(BERNOULLI_COLS))
126126
assert bern_fit.metric_type == 'dense_e'
127127
assert bern_fit.step_size.shape == (2,)
128-
assert bern_fit.metric.shape == (2, 1, 1)
128+
assert bern_fit.inv_metric.shape == (2, 1, 1)
129129

130130
bern_fit = bern_model.sample(
131131
data=jdata,
@@ -186,9 +186,7 @@ def test_bernoulli_good(stanfile: str):
186186

187187

188188
@pytest.mark.parametrize("stanfile", ["bernoulli.stan"])
189-
def test_bernoulli_unit_e(
190-
stanfile: str, caplog: pytest.LogCaptureFixture
191-
) -> None:
189+
def test_bernoulli_unit_e(stanfile: str) -> None:
192190
stan = os.path.join(DATAFILES_PATH, stanfile)
193191
bern_model = CmdStanModel(stan_file=stan)
194192

@@ -204,19 +202,9 @@ def test_bernoulli_unit_e(
204202
show_progress=False,
205203
)
206204
assert bern_fit.metric_type == 'unit_e'
207-
assert bern_fit.metric is None
205+
assert bern_fit.inv_metric is None
208206
assert bern_fit.step_size.shape == (2,)
209-
with caplog.at_level(logging.INFO):
210-
logging.getLogger()
211-
assert bern_fit.metric is None
212-
check_present(
213-
caplog,
214-
(
215-
'cmdstanpy',
216-
'INFO',
217-
'Unit diagnonal metric, inverse mass matrix size unknown.',
218-
),
219-
)
207+
220208
assert bern_fit.draws().shape == (100, 2, len(BERNOULLI_COLS))
221209

222210

@@ -535,7 +523,7 @@ def test_fixed_param_good() -> None:
535523
)
536524
assert datagen_fit.runset._args.method == Method.SAMPLE
537525
assert datagen_fit.metric_type is None
538-
assert datagen_fit.metric is None
526+
assert datagen_fit.inv_metric is None
539527
assert datagen_fit.step_size is None
540528
assert datagen_fit.divergences is None
541529
assert datagen_fit.max_treedepths is None
@@ -638,7 +626,7 @@ def test_fixed_param_good() -> None:
638626
assert datagen_fit.column_names == tuple(column_names)
639627
assert datagen_fit.num_draws_sampling == 100
640628
assert datagen_fit.draws().shape == (100, 1, len(column_names))
641-
assert datagen_fit.metric is None
629+
assert datagen_fit.inv_metric is None
642630
assert datagen_fit.metric_type is None
643631
assert datagen_fit.step_size is None
644632

@@ -860,7 +848,7 @@ def test_validate_big_run() -> None:
860848
assert fit.column_names == tuple(column_names)
861849
assert fit.metric_type == 'diag_e'
862850
assert fit.step_size.shape == (2,)
863-
assert fit.metric.shape == (2, 2095)
851+
assert fit.inv_metric.shape == (2, 2095)
864852
assert fit.draws().shape == (1000, 2, 2102)
865853
assert fit.draws_pd(vars=['phi']).shape == (2000, 2095)
866854
with raises_nested(ValueError, r'Unknown variable: gamma'):
@@ -2136,8 +2124,8 @@ def test_sample_dense_mass_matrix():
21362124
linear_model = CmdStanModel(stan_file=stan)
21372125

21382126
fit = linear_model.sample(data=jdata, metric="dense_e", chains=2)
2139-
assert fit.metric is not None
2140-
assert fit.metric.shape == (2, 3, 3)
2127+
assert fit.inv_metric is not None
2128+
assert fit.inv_metric.shape == (2, 3, 3)
21412129

21422130

21432131
def test_no_output_draws():

0 commit comments

Comments
 (0)