Skip to content

Commit 7a214d5

Browse files
committed
Add tests for csv parsing empty data
1 parent ab4afd4 commit 7a214d5

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

test/test_sample.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2138,3 +2138,13 @@ def test_sample_dense_mass_matrix():
21382138
fit = linear_model.sample(data=jdata, metric="dense_e", chains=2)
21392139
assert fit.metric is not None
21402140
assert fit.metric.shape == (2, 3, 3)
2141+
2142+
2143+
def test_no_output_draws():
2144+
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
2145+
model = cmdstanpy.CmdStanModel(stan_file=stan)
2146+
data = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
2147+
2148+
mcmc = model.sample(data=data, iter_sampling=0, save_warmup=False, chains=2)
2149+
draws = mcmc.draws()
2150+
assert np.array_equal(draws, np.empty((0, 2, len(mcmc.column_names))))

test/test_stancsv.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,10 @@ def test_csv_bytes_to_numpy_with_header_no_polars():
136136
assert np.array_equal(arr_out, expected)
137137

138138

139-
def test_csv_bytes_to_numpy_empty():
140-
lines = [b""]
141-
with pytest.raises(ValueError):
142-
stancsv.csv_bytes_list_to_numpy(lines)
143-
144-
145-
def test_csv_bytes_to_numpy_empty_no_polars():
146-
lines = [b""]
147-
with without_import("polars", cmdstanpy.utils.stancsv):
148-
with pytest.raises(ValueError):
149-
stancsv.csv_bytes_list_to_numpy(lines)
139+
def test_csv_bytes_empty():
140+
lines = []
141+
arr = stancsv.csv_bytes_list_to_numpy(lines)
142+
assert np.array_equal(arr, np.empty((0,)))
150143

151144

152145
def test_csv_bytes_to_numpy_header_no_draws():
@@ -156,8 +149,8 @@ def test_csv_bytes_to_numpy_header_no_draws():
156149
b"n_leapfrog__,divergent__,energy__,theta\n"
157150
),
158151
]
159-
with pytest.raises(ValueError):
160-
stancsv.csv_bytes_list_to_numpy(lines)
152+
arr = stancsv.csv_bytes_list_to_numpy(lines)
153+
assert arr.shape == (0, 8)
161154

162155

163156
def test_csv_bytes_to_numpy_header_no_draws_no_polars():
@@ -167,9 +160,8 @@ def test_csv_bytes_to_numpy_header_no_draws_no_polars():
167160
b"n_leapfrog__,divergent__,energy__,theta\n"
168161
),
169162
]
170-
with without_import("polars", cmdstanpy.utils.stancsv):
171-
with pytest.raises(ValueError):
172-
stancsv.csv_bytes_list_to_numpy(lines)
163+
arr = stancsv.csv_bytes_list_to_numpy(lines)
164+
assert arr.shape == (0, 8)
173165

174166

175167
def test_parse_comments_and_draws():

0 commit comments

Comments
 (0)