Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions cmdstanpy/utils/stancsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ def csv_bytes_list_to_numpy(
"""Efficiently converts a list of bytes representing whose concatenation
represents a CSV file into a numpy array. Includes header specifies
whether the bytes contains an initial header line."""
if not csv_bytes_list:
return np.empty((0,))
Comment on lines +63 to +64
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure it's worth complicating the logic, but it would make sense to me if a completely empty list was illegal when includes_header=True. I don't think that would ever happen on a valid csv file, though

num_cols = csv_bytes_list[0].count(b",") + 1
try:
import polars as pl

try:
if not csv_bytes_list:
raise ValueError("No data found to parse")
num_cols = csv_bytes_list[0].count(b",") + 1
out: npt.NDArray[np.float64] = (
pl.read_csv(
io.BytesIO(b"".join(csv_bytes_list)),
Expand All @@ -77,10 +77,8 @@ def csv_bytes_list_to_numpy(
.to_numpy()
.astype(np.float64)
)
if out.shape[0] == 0:
raise ValueError("No data found to parse")
except pl.exceptions.NoDataError as exc:
raise ValueError("No data found to parse") from exc
except pl.exceptions.NoDataError:
return np.empty((0, num_cols))
except ImportError:
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
Expand All @@ -91,10 +89,11 @@ def csv_bytes_list_to_numpy(
dtype=np.float64,
ndmin=1,
)
if out.shape == (0,):
raise ValueError("No data found to parse") # pylint: disable=W0707
if len(out.shape) == 1:
out = out.reshape(1, -1)
if out.shape[0] == 0: # No data read
out = np.empty((0, num_cols))
else:
out = out.reshape(1, -1)

return out

Expand Down
10 changes: 10 additions & 0 deletions test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2138,3 +2138,13 @@ def test_sample_dense_mass_matrix():
fit = linear_model.sample(data=jdata, metric="dense_e", chains=2)
assert fit.metric is not None
assert fit.metric.shape == (2, 3, 3)


def test_no_output_draws():
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
model = cmdstanpy.CmdStanModel(stan_file=stan)
data = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

mcmc = model.sample(data=data, iter_sampling=0, save_warmup=False, chains=2)
draws = mcmc.draws()
assert np.array_equal(draws, np.empty((0, 2, len(mcmc.column_names))))
23 changes: 8 additions & 15 deletions test/test_stancsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,10 @@ def test_csv_bytes_to_numpy_with_header_no_polars():
assert np.array_equal(arr_out, expected)


def test_csv_bytes_to_numpy_empty():
lines = [b""]
with pytest.raises(ValueError):
stancsv.csv_bytes_list_to_numpy(lines)


def test_csv_bytes_to_numpy_empty_no_polars():
lines = [b""]
with without_import("polars", cmdstanpy.utils.stancsv):
with pytest.raises(ValueError):
stancsv.csv_bytes_list_to_numpy(lines)
def test_csv_bytes_empty():
lines = []
arr = stancsv.csv_bytes_list_to_numpy(lines)
assert np.array_equal(arr, np.empty((0,)))


def test_csv_bytes_to_numpy_header_no_draws():
Expand All @@ -156,8 +149,8 @@ def test_csv_bytes_to_numpy_header_no_draws():
b"n_leapfrog__,divergent__,energy__,theta\n"
),
]
with pytest.raises(ValueError):
stancsv.csv_bytes_list_to_numpy(lines)
arr = stancsv.csv_bytes_list_to_numpy(lines)
assert arr.shape == (0, 8)


def test_csv_bytes_to_numpy_header_no_draws_no_polars():
Expand All @@ -168,8 +161,8 @@ def test_csv_bytes_to_numpy_header_no_draws_no_polars():
),
]
with without_import("polars", cmdstanpy.utils.stancsv):
with pytest.raises(ValueError):
stancsv.csv_bytes_list_to_numpy(lines)
arr = stancsv.csv_bytes_list_to_numpy(lines)
assert arr.shape == (0, 8)


def test_parse_comments_and_draws():
Expand Down
Loading