Skip to content

Commit 82a85d2

Browse files
committed
Remove unused chain_id from MetricInfo
1 parent 0f5dab8 commit 82a85d2

File tree

3 files changed

+19
-53
lines changed

3 files changed

+19
-53
lines changed

cmdstanpy/stanfit/mcmc.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -400,12 +400,13 @@ def _validate_csv_files(self) -> dict[str, Any]:
400400
def _parse_metric_info(self) -> None:
401401
"""Extracts metric type, inv_metric, and step size information from the
402402
parsed metric JSONs."""
403-
self._chain_metric_info = [
404-
MetricInfo.from_json(mf, chain_id)
405-
for mf, chain_id in zip(
406-
self.runset.metric_files, self.runset.chain_ids
407-
)
408-
]
403+
self._chain_metric_info = []
404+
for mf in self.runset.metric_files:
405+
with open(mf) as f:
406+
self._chain_metric_info.append(
407+
MetricInfo.model_validate_json(f.read())
408+
)
409+
409410
metric_types = {cmi.metric_type for cmi in self._chain_metric_info}
410411
if len(metric_types) != 1:
411412
raise ValueError("Inconsistent metric types found across chains")

cmdstanpy/stanfit/metadata.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33
from __future__ import annotations
44

55
import copy
6-
import json
76
import math
87
import os
98
from typing import Any, Iterator, Literal
109

1110
import numpy as np
1211
import stanio
13-
from pydantic import BaseModel, Field, field_validator, model_validator
12+
from pydantic import BaseModel, field_validator, model_validator
1413

1514
from cmdstanpy.utils import stancsv
1615

@@ -91,7 +90,6 @@ class MetricInfo(BaseModel):
9190
"""Structured representation of HMC-NUTS metric information,
9291
as output by CmdStan"""
9392

94-
chain_id: int = Field(gt=0)
9593
stepsize: float
9694
metric_type: Literal["diag_e", "dense_e", "unit_e"]
9795
inv_metric: np.ndarray
@@ -126,12 +124,3 @@ def validate_inv_metric_shape(self) -> MetricInfo:
126124
raise ValueError("Dense inv_metric must be square")
127125

128126
return self
129-
130-
@classmethod
131-
def from_json(cls, file: str | os.PathLike, chain_id: int) -> MetricInfo:
132-
"""Parse and validate a metric json given a file path and chain_id"""
133-
with open(file) as f:
134-
info_dict = json.load(f)
135-
136-
info_dict['chain_id'] = chain_id
137-
return cls.model_validate(info_dict) # type: ignore

test/test_metadata.py

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Metadata tests"""
1+
"Metadata tests"
22

33
import json
44
import math
@@ -84,20 +84,17 @@ class TestMetricInfoValidators:
8484
def test_valid_diag_e_metric(self) -> None:
8585
"""Test valid diag_e metric with 1D array"""
8686
metric = MetricInfo(
87-
chain_id=1,
8887
stepsize=0.5,
8988
metric_type="diag_e",
9089
inv_metric=[1.0, 2.0, 3.0], # type: ignore
9190
)
92-
assert metric.chain_id == 1
9391
assert metric.stepsize == 0.5
9492
assert isinstance(metric.inv_metric, np.ndarray)
9593
assert metric.inv_metric.ndim == 1
9694

9795
def test_valid_unit_e_metric(self) -> None:
9896
"""Test valid unit_e metric with 1D array"""
9997
metric = MetricInfo(
100-
chain_id=2,
10198
stepsize=0.1,
10299
metric_type="unit_e",
103100
inv_metric=[1.0, 1.0, 1.0], # type: ignore
@@ -108,7 +105,6 @@ def test_valid_unit_e_metric(self) -> None:
108105
def test_valid_dense_e_metric(self) -> None:
109106
"""Test valid dense_e metric with 2D square array"""
110107
metric = MetricInfo(
111-
chain_id=1,
112108
stepsize=0.3,
113109
metric_type="dense_e",
114110
inv_metric=[[1.0, 0.5], [0.5, 1.0]], # type: ignore
@@ -120,7 +116,6 @@ def test_valid_dense_e_metric(self) -> None:
120116
def test_convert_inv_metric_from_list(self) -> None:
121117
"""Test that inv_metric is converted to numpy array from list"""
122118
metric = MetricInfo(
123-
chain_id=1,
124119
stepsize=0.5,
125120
metric_type="diag_e",
126121
inv_metric=[1.0, 2.0, 3.0], # type: ignore
@@ -130,7 +125,6 @@ def test_convert_inv_metric_from_list(self) -> None:
130125
def test_convert_inv_metric_from_nested_list(self) -> None:
131126
"""Test that inv_metric is converted to numpy array from nested list"""
132127
metric = MetricInfo(
133-
chain_id=1,
134128
stepsize=0.5,
135129
metric_type="dense_e",
136130
inv_metric=[[1.0, 0.0], [0.0, 1.0]], # type: ignore
@@ -140,7 +134,6 @@ def test_convert_inv_metric_from_nested_list(self) -> None:
140134
def test_stepsize_positive(self) -> None:
141135
"""Test valid positive stepsize"""
142136
metric = MetricInfo(
143-
chain_id=1,
144137
stepsize=0.5,
145138
metric_type="diag_e",
146139
inv_metric=[1.0], # type: ignore
@@ -150,7 +143,6 @@ def test_stepsize_positive(self) -> None:
150143
def test_stepsize_nan_allowed(self) -> None:
151144
"""Test that NaN stepsize is allowed"""
152145
metric = MetricInfo(
153-
chain_id=1,
154146
stepsize=math.nan,
155147
metric_type="diag_e",
156148
inv_metric=[1.0], # type: ignore
@@ -161,7 +153,6 @@ def test_stepsize_zero_raises_error(self) -> None:
161153
"""Test that zero stepsize raises ValueError"""
162154
with pytest.raises(ValidationError) as exc_info:
163155
MetricInfo(
164-
chain_id=1,
165156
stepsize=0.0,
166157
metric_type="diag_e",
167158
inv_metric=[1.0], # type: ignore
@@ -172,7 +163,6 @@ def test_stepsize_negative_raises_error(self) -> None:
172163
"""Test that negative stepsize raises ValueError"""
173164
with pytest.raises(ValidationError) as exc_info:
174165
MetricInfo(
175-
chain_id=1,
176166
stepsize=-0.5,
177167
metric_type="diag_e",
178168
inv_metric=[1.0], # type: ignore
@@ -183,7 +173,6 @@ def test_diag_e_with_2d_array_raises_error(self) -> None:
183173
"""Test that diag_e with 2D array raises ValueError"""
184174
with pytest.raises(ValidationError) as exc_info:
185175
MetricInfo(
186-
chain_id=1,
187176
stepsize=0.5,
188177
metric_type="diag_e",
189178
inv_metric=[[1.0, 2.0]], # type: ignore
@@ -196,7 +185,6 @@ def test_unit_e_with_2d_array_raises_error(self) -> None:
196185
"""Test that unit_e with 2D array raises ValueError"""
197186
with pytest.raises(ValidationError) as exc_info:
198187
MetricInfo(
199-
chain_id=1,
200188
stepsize=0.5,
201189
metric_type="unit_e",
202190
inv_metric=[[1.0], [1.0]], # type: ignore
@@ -209,7 +197,6 @@ def test_dense_e_with_1d_array_raises_error(self) -> None:
209197
"""Test that dense_e with 1D array raises ValueError"""
210198
with pytest.raises(ValidationError) as exc_info:
211199
MetricInfo(
212-
chain_id=1,
213200
stepsize=0.5,
214201
metric_type="dense_e",
215202
inv_metric=[1.0, 2.0], # type: ignore
@@ -220,27 +207,15 @@ def test_dense_e_non_square_raises_error(self) -> None:
220207
"""Test that dense_e with non-square array raises ValueError"""
221208
with pytest.raises(ValidationError) as exc_info:
222209
MetricInfo(
223-
chain_id=1,
224210
stepsize=0.5,
225211
metric_type="dense_e",
226212
inv_metric=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], # type: ignore
227213
)
228214
assert "Dense inv_metric must be square" in str(exc_info.value)
229215

230-
def test_chain_id_must_be_positive(self) -> None:
231-
"""Test that chain_id must be greater than 0"""
232-
with pytest.raises(ValidationError) as exc_info:
233-
MetricInfo(
234-
chain_id=0,
235-
stepsize=0.5,
236-
metric_type="diag_e",
237-
inv_metric=[1.0], # type: ignore
238-
)
239-
assert "greater than 0" in str(exc_info.value)
240-
241216

242-
class TestMetricInfoFromJson:
243-
"""Test from_json class method"""
217+
class TestMetricInfoModelValidateJson:
218+
"""Test model_validate_json class method"""
244219

245220
def test_from_json_diag_e(self) -> None:
246221
"""Test loading diag_e metric from JSON file"""
@@ -258,8 +233,8 @@ def test_from_json_diag_e(self) -> None:
258233
temp_path = f.name
259234

260235
try:
261-
metric = MetricInfo.from_json(temp_path, chain_id=1)
262-
assert metric.chain_id == 1
236+
with open(temp_path) as f:
237+
metric = MetricInfo.model_validate_json(f.read())
263238
assert metric.stepsize == 0.5
264239
assert metric.metric_type == "diag_e"
265240
assert np.array_equal(metric.inv_metric, np.array([1.0, 2.0, 3.0]))
@@ -282,8 +257,8 @@ def test_from_json_dense_e(self) -> None:
282257
temp_path = f.name
283258

284259
try:
285-
metric = MetricInfo.from_json(temp_path, chain_id=2)
286-
assert metric.chain_id == 2
260+
with open(temp_path) as f:
261+
metric = MetricInfo.model_validate_json(f.read())
287262
assert metric.stepsize == 0.3
288263
assert metric.metric_type == "dense_e"
289264
assert metric.inv_metric.shape == (2, 2)
@@ -307,7 +282,8 @@ def test_from_json_invalid_data_raises_error(self) -> None:
307282

308283
try:
309284
with pytest.raises(ValidationError):
310-
MetricInfo.from_json(temp_path, chain_id=1)
285+
with open(temp_path) as f:
286+
MetricInfo.model_validate_json(f.read())
311287
finally:
312288
Path(temp_path).unlink()
313289

@@ -327,8 +303,8 @@ def test_from_json_pathlike(self) -> None:
327303
temp_path = Path(f.name)
328304

329305
try:
330-
metric = MetricInfo.from_json(temp_path, chain_id=3)
331-
assert metric.chain_id == 3
306+
with open(temp_path) as f:
307+
metric = MetricInfo.model_validate_json(f.read())
332308
assert metric.metric_type == "unit_e"
333309
finally:
334310
temp_path.unlink()

0 commit comments

Comments
 (0)