Skip to content

Commit 61a61c2

Browse files
committed
Add test for csv round-tripping
1 parent bdec30c commit 61a61c2

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

test/test_sample.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,6 +2030,26 @@ def test_tuple_data_in() -> None:
20302030
data_model.sample(data, chains=1, iter_warmup=1, iter_sampling=1)
20312031

20322032

2033+
def test_csv_roundtrip():
2034+
stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan')
2035+
model = CmdStanModel(stan_file=stan)
2036+
fit = model.sample(
2037+
iter_sampling=10, iter_warmup=9, chains=2, save_warmup=True
2038+
)
2039+
z = fit.stan_variable(var="z")
2040+
assert z.shape == (20, 4, 3)
2041+
z_with_warmup = fit.stan_variable(var="z", inc_warmup=True)
2042+
assert z_with_warmup.shape == (38, 4, 3)
2043+
2044+
# mostly just asserting that from_csv always succeeds
2045+
# in parsing latest cmdstan headers
2046+
fit_from_csv = from_csv(fit.runset.csv_files)
2047+
z_from_csv = fit_from_csv.stan_variable(var="z")
2048+
assert z_from_csv.shape == (20, 4, 3)
2049+
z_with_warmup_from_csv = fit.stan_variable(var="z", inc_warmup=True)
2050+
assert z_with_warmup_from_csv.shape == (38, 4, 3)
2051+
2052+
20332053
@pytest.mark.order(before="test_no_xarray")
20342054
def test_serialization(stanfile='bernoulli.stan'):
20352055
# This test must before any test that uses the `without_import` context

0 commit comments

Comments
 (0)