Skip to content

Commit 71d97ca

Browse files
authored
Merge pull request #801 from amas0/fix-empty-draws
Revert csv parsing empty data to return empty arrays
2 parents 5383424 + d6e007e commit 71d97ca

File tree

3 files changed

+27
-25
lines changed

3 files changed

+27
-25
lines changed

cmdstanpy/utils/stancsv.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ def csv_bytes_list_to_numpy(
6060
"""Efficiently converts a list of bytes whose concatenation
6161
represents a CSV file into a numpy array. Includes header specifies
6262
whether the bytes contain an initial header line."""
63+
if not csv_bytes_list:
64+
return np.empty((0,))
65+
num_cols = csv_bytes_list[0].count(b",") + 1
6366
try:
6467
import polars as pl
6568

6669
try:
67-
if not csv_bytes_list:
68-
raise ValueError("No data found to parse")
69-
num_cols = csv_bytes_list[0].count(b",") + 1
7070
out: npt.NDArray[np.float64] = (
7171
pl.read_csv(
7272
io.BytesIO(b"".join(csv_bytes_list)),
@@ -77,10 +77,8 @@ def csv_bytes_list_to_numpy(
7777
.to_numpy()
7878
.astype(np.float64)
7979
)
80-
if out.shape[0] == 0:
81-
raise ValueError("No data found to parse")
82-
except pl.exceptions.NoDataError as exc:
83-
raise ValueError("No data found to parse") from exc
80+
except pl.exceptions.NoDataError:
81+
return np.empty((0, num_cols))
8482
except ImportError:
8583
with warnings.catch_warnings():
8684
warnings.filterwarnings("ignore")
@@ -91,10 +89,11 @@ def csv_bytes_list_to_numpy(
9189
dtype=np.float64,
9290
ndmin=1,
9391
)
94-
if out.shape == (0,):
95-
raise ValueError("No data found to parse") # pylint: disable=W0707
9692
if len(out.shape) == 1:
97-
out = out.reshape(1, -1)
93+
if out.shape[0] == 0: # No data read
94+
out = np.empty((0, num_cols))
95+
else:
96+
out = out.reshape(1, -1)
9897

9998
return out
10099

test/test_sample.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2138,3 +2138,13 @@ def test_sample_dense_mass_matrix():
21382138
fit = linear_model.sample(data=jdata, metric="dense_e", chains=2)
21392139
assert fit.metric is not None
21402140
assert fit.metric.shape == (2, 3, 3)
2141+
2142+
2143+
def test_no_output_draws():
2144+
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
2145+
model = cmdstanpy.CmdStanModel(stan_file=stan)
2146+
data = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
2147+
2148+
mcmc = model.sample(data=data, iter_sampling=0, save_warmup=False, chains=2)
2149+
draws = mcmc.draws()
2150+
assert np.array_equal(draws, np.empty((0, 2, len(mcmc.column_names))))

test/test_stancsv.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,10 @@ def test_csv_bytes_to_numpy_with_header_no_polars():
136136
assert np.array_equal(arr_out, expected)
137137

138138

139-
def test_csv_bytes_to_numpy_empty():
140-
lines = [b""]
141-
with pytest.raises(ValueError):
142-
stancsv.csv_bytes_list_to_numpy(lines)
143-
144-
145-
def test_csv_bytes_to_numpy_empty_no_polars():
146-
lines = [b""]
147-
with without_import("polars", cmdstanpy.utils.stancsv):
148-
with pytest.raises(ValueError):
149-
stancsv.csv_bytes_list_to_numpy(lines)
139+
def test_csv_bytes_empty():
140+
lines = []
141+
arr = stancsv.csv_bytes_list_to_numpy(lines)
142+
assert np.array_equal(arr, np.empty((0,)))
150143

151144

152145
def test_csv_bytes_to_numpy_header_no_draws():
@@ -156,8 +149,8 @@ def test_csv_bytes_to_numpy_header_no_draws():
156149
b"n_leapfrog__,divergent__,energy__,theta\n"
157150
),
158151
]
159-
with pytest.raises(ValueError):
160-
stancsv.csv_bytes_list_to_numpy(lines)
152+
arr = stancsv.csv_bytes_list_to_numpy(lines)
153+
assert arr.shape == (0, 8)
161154

162155

163156
def test_csv_bytes_to_numpy_header_no_draws_no_polars():
@@ -168,8 +161,8 @@ def test_csv_bytes_to_numpy_header_no_draws_no_polars():
168161
),
169162
]
170163
with without_import("polars", cmdstanpy.utils.stancsv):
171-
with pytest.raises(ValueError):
172-
stancsv.csv_bytes_list_to_numpy(lines)
164+
arr = stancsv.csv_bytes_list_to_numpy(lines)
165+
assert arr.shape == (0, 8)
173166

174167

175168
def test_parse_comments_and_draws():

0 commit comments

Comments
 (0)