Skip to content

Commit 0f5dab8

Browse files
committed
Add tests of MetricInfo validators
1 parent d80461d commit 0f5dab8

File tree

1 file changed

+265
-0
lines changed

1 file changed

+265
-0
lines changed

test/test_metadata.py

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
11
"""Metadata tests"""
22

3+
import json
4+
import math
35
import os
6+
import tempfile
7+
from pathlib import Path
8+
9+
import numpy as np
10+
import pytest
11+
from pydantic import ValidationError
412

513
from cmdstanpy.cmdstan_args import CmdStanArgs, SamplerArgs
614
from cmdstanpy.stanfit import InferenceMetadata, RunSet
15+
from cmdstanpy.stanfit.metadata import MetricInfo
716
from cmdstanpy.utils import EXTENSION, check_sampler_csv
817

918
HERE = os.path.dirname(os.path.abspath(__file__))
@@ -67,3 +76,259 @@ def test_good() -> None:
6776
assert hmc_vars == method_vars_cols.keys()
6877
bern_model_vars = {'theta'}
6978
assert bern_model_vars == metadata.stan_vars.keys()
79+
80+
81+
class TestMetricInfoValidators:
82+
"""Test custom validators for MetricInfo model"""
83+
84+
def test_valid_diag_e_metric(self) -> None:
85+
"""Test valid diag_e metric with 1D array"""
86+
metric = MetricInfo(
87+
chain_id=1,
88+
stepsize=0.5,
89+
metric_type="diag_e",
90+
inv_metric=[1.0, 2.0, 3.0], # type: ignore
91+
)
92+
assert metric.chain_id == 1
93+
assert metric.stepsize == 0.5
94+
assert isinstance(metric.inv_metric, np.ndarray)
95+
assert metric.inv_metric.ndim == 1
96+
97+
def test_valid_unit_e_metric(self) -> None:
98+
"""Test valid unit_e metric with 1D array"""
99+
metric = MetricInfo(
100+
chain_id=2,
101+
stepsize=0.1,
102+
metric_type="unit_e",
103+
inv_metric=[1.0, 1.0, 1.0], # type: ignore
104+
)
105+
assert metric.metric_type == "unit_e"
106+
assert metric.inv_metric.ndim == 1
107+
108+
def test_valid_dense_e_metric(self) -> None:
109+
"""Test valid dense_e metric with 2D square array"""
110+
metric = MetricInfo(
111+
chain_id=1,
112+
stepsize=0.3,
113+
metric_type="dense_e",
114+
inv_metric=[[1.0, 0.5], [0.5, 1.0]], # type: ignore
115+
)
116+
assert metric.metric_type == "dense_e"
117+
assert metric.inv_metric.ndim == 2
118+
assert metric.inv_metric.shape == (2, 2)
119+
120+
def test_convert_inv_metric_from_list(self) -> None:
121+
"""Test that inv_metric is converted to numpy array from list"""
122+
metric = MetricInfo(
123+
chain_id=1,
124+
stepsize=0.5,
125+
metric_type="diag_e",
126+
inv_metric=[1.0, 2.0, 3.0], # type: ignore
127+
)
128+
assert isinstance(metric.inv_metric, np.ndarray)
129+
130+
def test_convert_inv_metric_from_nested_list(self) -> None:
131+
"""Test that inv_metric is converted to numpy array from nested list"""
132+
metric = MetricInfo(
133+
chain_id=1,
134+
stepsize=0.5,
135+
metric_type="dense_e",
136+
inv_metric=[[1.0, 0.0], [0.0, 1.0]], # type: ignore
137+
)
138+
assert isinstance(metric.inv_metric, np.ndarray)
139+
140+
def test_stepsize_positive(self) -> None:
141+
"""Test valid positive stepsize"""
142+
metric = MetricInfo(
143+
chain_id=1,
144+
stepsize=0.5,
145+
metric_type="diag_e",
146+
inv_metric=[1.0], # type: ignore
147+
)
148+
assert metric.stepsize == 0.5
149+
150+
def test_stepsize_nan_allowed(self) -> None:
151+
"""Test that NaN stepsize is allowed"""
152+
metric = MetricInfo(
153+
chain_id=1,
154+
stepsize=math.nan,
155+
metric_type="diag_e",
156+
inv_metric=[1.0], # type: ignore
157+
)
158+
assert math.isnan(metric.stepsize)
159+
160+
def test_stepsize_zero_raises_error(self) -> None:
161+
"""Test that zero stepsize raises ValueError"""
162+
with pytest.raises(ValidationError) as exc_info:
163+
MetricInfo(
164+
chain_id=1,
165+
stepsize=0.0,
166+
metric_type="diag_e",
167+
inv_metric=[1.0], # type: ignore
168+
)
169+
assert "stepsize must be greater than 0 or NaN" in str(exc_info.value)
170+
171+
def test_stepsize_negative_raises_error(self) -> None:
172+
"""Test that negative stepsize raises ValueError"""
173+
with pytest.raises(ValidationError) as exc_info:
174+
MetricInfo(
175+
chain_id=1,
176+
stepsize=-0.5,
177+
metric_type="diag_e",
178+
inv_metric=[1.0], # type: ignore
179+
)
180+
assert "stepsize must be greater than 0 or NaN" in str(exc_info.value)
181+
182+
def test_diag_e_with_2d_array_raises_error(self) -> None:
183+
"""Test that diag_e with 2D array raises ValueError"""
184+
with pytest.raises(ValidationError) as exc_info:
185+
MetricInfo(
186+
chain_id=1,
187+
stepsize=0.5,
188+
metric_type="diag_e",
189+
inv_metric=[[1.0, 2.0]], # type: ignore
190+
)
191+
assert "inv_metric must be 1D for diag_e and unit_e" in str(
192+
exc_info.value
193+
)
194+
195+
def test_unit_e_with_2d_array_raises_error(self) -> None:
196+
"""Test that unit_e with 2D array raises ValueError"""
197+
with pytest.raises(ValidationError) as exc_info:
198+
MetricInfo(
199+
chain_id=1,
200+
stepsize=0.5,
201+
metric_type="unit_e",
202+
inv_metric=[[1.0], [1.0]], # type: ignore
203+
)
204+
assert "inv_metric must be 1D for diag_e and unit_e" in str(
205+
exc_info.value
206+
)
207+
208+
def test_dense_e_with_1d_array_raises_error(self) -> None:
209+
"""Test that dense_e with 1D array raises ValueError"""
210+
with pytest.raises(ValidationError) as exc_info:
211+
MetricInfo(
212+
chain_id=1,
213+
stepsize=0.5,
214+
metric_type="dense_e",
215+
inv_metric=[1.0, 2.0], # type: ignore
216+
)
217+
assert "Dense inv_metric must be 2D" in str(exc_info.value)
218+
219+
def test_dense_e_non_square_raises_error(self) -> None:
220+
"""Test that dense_e with non-square array raises ValueError"""
221+
with pytest.raises(ValidationError) as exc_info:
222+
MetricInfo(
223+
chain_id=1,
224+
stepsize=0.5,
225+
metric_type="dense_e",
226+
inv_metric=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], # type: ignore
227+
)
228+
assert "Dense inv_metric must be square" in str(exc_info.value)
229+
230+
def test_chain_id_must_be_positive(self) -> None:
231+
"""Test that chain_id must be greater than 0"""
232+
with pytest.raises(ValidationError) as exc_info:
233+
MetricInfo(
234+
chain_id=0,
235+
stepsize=0.5,
236+
metric_type="diag_e",
237+
inv_metric=[1.0], # type: ignore
238+
)
239+
assert "greater than 0" in str(exc_info.value)
240+
241+
242+
class TestMetricInfoFromJson:
243+
"""Test from_json class method"""
244+
245+
def test_from_json_diag_e(self) -> None:
246+
"""Test loading diag_e metric from JSON file"""
247+
with tempfile.NamedTemporaryFile(
248+
mode='w', suffix='.json', delete=False
249+
) as f:
250+
json.dump(
251+
{
252+
'stepsize': 0.5,
253+
'metric_type': 'diag_e',
254+
'inv_metric': [1.0, 2.0, 3.0],
255+
},
256+
f,
257+
)
258+
temp_path = f.name
259+
260+
try:
261+
metric = MetricInfo.from_json(temp_path, chain_id=1)
262+
assert metric.chain_id == 1
263+
assert metric.stepsize == 0.5
264+
assert metric.metric_type == "diag_e"
265+
assert np.array_equal(metric.inv_metric, np.array([1.0, 2.0, 3.0]))
266+
finally:
267+
Path(temp_path).unlink()
268+
269+
def test_from_json_dense_e(self) -> None:
270+
"""Test loading dense_e metric from JSON file"""
271+
with tempfile.NamedTemporaryFile(
272+
mode='w', suffix='.json', delete=False
273+
) as f:
274+
json.dump(
275+
{
276+
'stepsize': 0.3,
277+
'metric_type': 'dense_e',
278+
'inv_metric': [[1.0, 0.5], [0.5, 1.0]],
279+
},
280+
f,
281+
)
282+
temp_path = f.name
283+
284+
try:
285+
metric = MetricInfo.from_json(temp_path, chain_id=2)
286+
assert metric.chain_id == 2
287+
assert metric.stepsize == 0.3
288+
assert metric.metric_type == "dense_e"
289+
assert metric.inv_metric.shape == (2, 2)
290+
finally:
291+
Path(temp_path).unlink()
292+
293+
def test_from_json_invalid_data_raises_error(self) -> None:
294+
"""Test that invalid data in JSON raises ValidationError"""
295+
with tempfile.NamedTemporaryFile(
296+
mode='w', suffix='.json', delete=False
297+
) as f:
298+
json.dump(
299+
{
300+
'stepsize': -0.5, # Invalid: negative stepsize
301+
'metric_type': 'diag_e',
302+
'inv_metric': [1.0, 2.0, 3.0],
303+
},
304+
f,
305+
)
306+
temp_path = f.name
307+
308+
try:
309+
with pytest.raises(ValidationError):
310+
MetricInfo.from_json(temp_path, chain_id=1)
311+
finally:
312+
Path(temp_path).unlink()
313+
314+
def test_from_json_pathlike(self) -> None:
315+
"""Test from_json works with PathLike objects"""
316+
with tempfile.NamedTemporaryFile(
317+
mode='w', suffix='.json', delete=False
318+
) as f:
319+
json.dump(
320+
{
321+
'stepsize': 0.5,
322+
'metric_type': 'unit_e',
323+
'inv_metric': [1.0, 1.0],
324+
},
325+
f,
326+
)
327+
temp_path = Path(f.name)
328+
329+
try:
330+
metric = MetricInfo.from_json(temp_path, chain_id=3)
331+
assert metric.chain_id == 3
332+
assert metric.metric_type == "unit_e"
333+
finally:
334+
temp_path.unlink()

0 commit comments

Comments
 (0)