Skip to content

Commit 37080af

Browse files
committed
Convert MetricInfo.inv_metric to native Python types
1 parent 2497754 commit 37080af

File tree

2 files changed

+58
-54
lines changed

2 files changed

+58
-54
lines changed

cmdstanpy/stanfit/metadata.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import os
88
from typing import Any, Iterator, Literal
99

10-
import numpy as np
1110
import stanio
1211
from pydantic import BaseModel, field_validator, model_validator
1312

@@ -92,15 +91,7 @@ class MetricInfo(BaseModel):
9291

9392
stepsize: float
9493
metric_type: Literal["diag_e", "dense_e", "unit_e"]
95-
inv_metric: np.ndarray
96-
97-
# allows ndarray as pydantic attribute
98-
model_config = {"arbitrary_types_allowed": True}
99-
100-
@field_validator("inv_metric", mode="before")
101-
@classmethod
102-
def convert_inv_metric(cls, v: Any) -> np.ndarray:
103-
return np.asarray(v)
94+
inv_metric: list[float] | list[list[float]]
10495

10596
@field_validator("stepsize")
10697
@classmethod
@@ -111,17 +102,26 @@ def validate_stepsize(cls, v: float) -> float:
111102

112103
@model_validator(mode="after")
113104
def validate_inv_metric_shape(self) -> MetricInfo:
114-
if (
115-
self.metric_type in ("diag_e", "unit_e")
116-
and self.inv_metric.ndim != 1
117-
):
105+
if not self.inv_metric: # Empty inv_metric, e.g. from no parameters
106+
return self
107+
108+
is_1d = isinstance(self.inv_metric[0], float)
109+
110+
if self.metric_type in ("diag_e", "unit_e") and not is_1d:
118111
raise ValueError(
119112
"inv_metric must be 1D for diag_e and unit_e metric type"
120113
)
121114
if self.metric_type == "dense_e":
122-
if self.inv_metric.ndim != 2:
115+
if is_1d:
123116
raise ValueError("Dense inv_metric must be 2D")
124-
if self.inv_metric.shape[0] != self.inv_metric.shape[1]:
117+
118+
if any(not row for row in self.inv_metric):
119+
raise ValueError("Dense inv_metric cannot contain empty rows")
120+
121+
n_rows = len(self.inv_metric)
122+
if not all(
123+
len(row) == n_rows for row in self.inv_metric # type: ignore
124+
):
125125
raise ValueError("Dense inv_metric must be square")
126126

127127
return self

test/test_metadata.py

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import tempfile
77
from pathlib import Path
88

9-
import numpy as np
109
import pytest
1110
from pydantic import ValidationError
1211

@@ -82,61 +81,65 @@ class TestMetricInfoValidators:
8281
"""Test custom validators for MetricInfo model"""
8382

8483
def test_valid_diag_e_metric(self) -> None:
85-
"""Test valid diag_e metric with 1D array"""
84+
"""Test valid diag_e metric with 1D list"""
8685
metric = MetricInfo(
8786
stepsize=0.5,
8887
metric_type="diag_e",
89-
inv_metric=[1.0, 2.0, 3.0], # type: ignore
88+
inv_metric=[1.0, 2.0, 3.0],
9089
)
9190
assert metric.stepsize == 0.5
92-
assert isinstance(metric.inv_metric, np.ndarray)
93-
assert metric.inv_metric.ndim == 1
91+
assert isinstance(metric.inv_metric, list)
92+
assert isinstance(metric.inv_metric[0], float)
93+
assert len(metric.inv_metric) == 3
9494

9595
def test_valid_unit_e_metric(self) -> None:
96-
"""Test valid unit_e metric with 1D array"""
96+
"""Test valid unit_e metric with 1D list"""
9797
metric = MetricInfo(
9898
stepsize=0.1,
9999
metric_type="unit_e",
100-
inv_metric=[1.0, 1.0, 1.0], # type: ignore
100+
inv_metric=[1.0, 1.0, 1.0],
101101
)
102102
assert metric.metric_type == "unit_e"
103-
assert metric.inv_metric.ndim == 1
103+
assert isinstance(metric.inv_metric[0], float)
104+
assert len(metric.inv_metric) == 3
104105

105106
def test_valid_dense_e_metric(self) -> None:
106-
"""Test valid dense_e metric with 2D square array"""
107+
"""Test valid dense_e metric with 2D square list"""
107108
metric = MetricInfo(
108109
stepsize=0.3,
109110
metric_type="dense_e",
110-
inv_metric=[[1.0, 0.5], [0.5, 1.0]], # type: ignore
111+
inv_metric=[[1.0, 0.5], [0.5, 1.0]],
111112
)
112113
assert metric.metric_type == "dense_e"
113-
assert metric.inv_metric.ndim == 2
114-
assert metric.inv_metric.shape == (2, 2)
114+
assert isinstance(metric.inv_metric[0], list)
115+
assert len(metric.inv_metric) == 2
116+
assert len(metric.inv_metric[0]) == 2
115117

116-
def test_convert_inv_metric_from_list(self) -> None:
117-
"""Test that inv_metric is converted to numpy array from list"""
118+
def test_inv_metric_stays_as_list(self) -> None:
119+
"""Test that inv_metric remains as list type"""
118120
metric = MetricInfo(
119121
stepsize=0.5,
120122
metric_type="diag_e",
121-
inv_metric=[1.0, 2.0, 3.0], # type: ignore
123+
inv_metric=[1.0, 2.0, 3.0],
122124
)
123-
assert isinstance(metric.inv_metric, np.ndarray)
125+
assert isinstance(metric.inv_metric, list)
124126

125-
def test_convert_inv_metric_from_nested_list(self) -> None:
126-
"""Test that inv_metric is converted to numpy array from nested list"""
127+
def test_inv_metric_nested_list(self) -> None:
128+
"""Test that inv_metric handles nested lists correctly"""
127129
metric = MetricInfo(
128130
stepsize=0.5,
129131
metric_type="dense_e",
130-
inv_metric=[[1.0, 0.0], [0.0, 1.0]], # type: ignore
132+
inv_metric=[[1.0, 0.0], [0.0, 1.0]],
131133
)
132-
assert isinstance(metric.inv_metric, np.ndarray)
134+
assert isinstance(metric.inv_metric, list)
135+
assert isinstance(metric.inv_metric[0], list)
133136

134137
def test_stepsize_positive(self) -> None:
135138
"""Test valid positive stepsize"""
136139
metric = MetricInfo(
137140
stepsize=0.5,
138141
metric_type="diag_e",
139-
inv_metric=[1.0], # type: ignore
142+
inv_metric=[1.0],
140143
)
141144
assert metric.stepsize == 0.5
142145

@@ -145,7 +148,7 @@ def test_stepsize_nan_allowed(self) -> None:
145148
metric = MetricInfo(
146149
stepsize=math.nan,
147150
metric_type="diag_e",
148-
inv_metric=[1.0], # type: ignore
151+
inv_metric=[1.0],
149152
)
150153
assert math.isnan(metric.stepsize)
151154

@@ -155,7 +158,7 @@ def test_stepsize_zero_raises_error(self) -> None:
155158
MetricInfo(
156159
stepsize=0.0,
157160
metric_type="diag_e",
158-
inv_metric=[1.0], # type: ignore
161+
inv_metric=[1.0],
159162
)
160163
assert "stepsize must be greater than 0 or NaN" in str(exc_info.value)
161164

@@ -165,51 +168,51 @@ def test_stepsize_negative_raises_error(self) -> None:
165168
MetricInfo(
166169
stepsize=-0.5,
167170
metric_type="diag_e",
168-
inv_metric=[1.0], # type: ignore
171+
inv_metric=[1.0],
169172
)
170173
assert "stepsize must be greater than 0 or NaN" in str(exc_info.value)
171174

172-
def test_diag_e_with_2d_array_raises_error(self) -> None:
173-
"""Test that diag_e with 2D array raises ValueError"""
175+
def test_diag_e_with_2d_list_raises_error(self) -> None:
176+
"""Test that diag_e with 2D list raises ValueError"""
174177
with pytest.raises(ValidationError) as exc_info:
175178
MetricInfo(
176179
stepsize=0.5,
177180
metric_type="diag_e",
178-
inv_metric=[[1.0, 2.0]], # type: ignore
181+
inv_metric=[[1.0, 2.0]],
179182
)
180183
assert "inv_metric must be 1D for diag_e and unit_e" in str(
181184
exc_info.value
182185
)
183186

184-
def test_unit_e_with_2d_array_raises_error(self) -> None:
185-
"""Test that unit_e with 2D array raises ValueError"""
187+
def test_unit_e_with_2d_list_raises_error(self) -> None:
188+
"""Test that unit_e with 2D list raises ValueError"""
186189
with pytest.raises(ValidationError) as exc_info:
187190
MetricInfo(
188191
stepsize=0.5,
189192
metric_type="unit_e",
190-
inv_metric=[[1.0], [1.0]], # type: ignore
193+
inv_metric=[[1.0], [1.0]],
191194
)
192195
assert "inv_metric must be 1D for diag_e and unit_e" in str(
193196
exc_info.value
194197
)
195198

196-
def test_dense_e_with_1d_array_raises_error(self) -> None:
197-
"""Test that dense_e with 1D array raises ValueError"""
199+
def test_dense_e_with_1d_list_raises_error(self) -> None:
200+
"""Test that dense_e with 1D list raises ValueError"""
198201
with pytest.raises(ValidationError) as exc_info:
199202
MetricInfo(
200203
stepsize=0.5,
201204
metric_type="dense_e",
202-
inv_metric=[1.0, 2.0], # type: ignore
205+
inv_metric=[1.0, 2.0],
203206
)
204207
assert "Dense inv_metric must be 2D" in str(exc_info.value)
205208

206209
def test_dense_e_non_square_raises_error(self) -> None:
207-
"""Test that dense_e with non-square array raises ValueError"""
210+
"""Test that dense_e with non-square list raises ValueError"""
208211
with pytest.raises(ValidationError) as exc_info:
209212
MetricInfo(
210213
stepsize=0.5,
211214
metric_type="dense_e",
212-
inv_metric=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], # type: ignore
215+
inv_metric=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
213216
)
214217
assert "Dense inv_metric must be square" in str(exc_info.value)
215218

@@ -237,7 +240,7 @@ def test_from_json_diag_e(self) -> None:
237240
metric = MetricInfo.model_validate_json(f.read())
238241
assert metric.stepsize == 0.5
239242
assert metric.metric_type == "diag_e"
240-
assert np.array_equal(metric.inv_metric, np.array([1.0, 2.0, 3.0]))
243+
assert metric.inv_metric == [1.0, 2.0, 3.0]
241244
finally:
242245
Path(temp_path).unlink()
243246

@@ -261,7 +264,8 @@ def test_from_json_dense_e(self) -> None:
261264
metric = MetricInfo.model_validate_json(f.read())
262265
assert metric.stepsize == 0.3
263266
assert metric.metric_type == "dense_e"
264-
assert metric.inv_metric.shape == (2, 2)
267+
assert len(metric.inv_metric) == 2
268+
assert len(metric.inv_metric[0]) == 2
265269
finally:
266270
Path(temp_path).unlink()
267271

@@ -314,7 +318,7 @@ def test_invalid_metric_type_raises_error(self) -> None:
314318
MetricInfo(
315319
stepsize=0.5,
316320
metric_type="not_a_metric", # type: ignore
317-
inv_metric=[1.0], # type: ignore
321+
inv_metric=[1.0],
318322
)
319323

320324
def test_from_json_invalid_metric_type_raises_error(self) -> None:

0 commit comments

Comments
 (0)