diff --git a/cmdstanpy/utils/stancsv.py b/cmdstanpy/utils/stancsv.py index 95d5ad4a..74830994 100644 --- a/cmdstanpy/utils/stancsv.py +++ b/cmdstanpy/utils/stancsv.py @@ -60,13 +60,13 @@ def csv_bytes_list_to_numpy( """Efficiently converts a list of bytes representing whose concatenation represents a CSV file into a numpy array. Includes header specifies whether the bytes contains an initial header line.""" + if not csv_bytes_list: + return np.empty((0,)) + num_cols = csv_bytes_list[0].count(b",") + 1 try: import polars as pl try: - if not csv_bytes_list: - raise ValueError("No data found to parse") - num_cols = csv_bytes_list[0].count(b",") + 1 out: npt.NDArray[np.float64] = ( pl.read_csv( io.BytesIO(b"".join(csv_bytes_list)), @@ -77,10 +77,8 @@ def csv_bytes_list_to_numpy( .to_numpy() .astype(np.float64) ) - if out.shape[0] == 0: - raise ValueError("No data found to parse") - except pl.exceptions.NoDataError as exc: - raise ValueError("No data found to parse") from exc + except pl.exceptions.NoDataError: + return np.empty((0, num_cols)) except ImportError: with warnings.catch_warnings(): warnings.filterwarnings("ignore") @@ -91,10 +89,11 @@ def csv_bytes_list_to_numpy( dtype=np.float64, ndmin=1, ) - if out.shape == (0,): - raise ValueError("No data found to parse") # pylint: disable=W0707 if len(out.shape) == 1: - out = out.reshape(1, -1) + if out.shape[0] == 0: # No data read + out = np.empty((0, num_cols)) + else: + out = out.reshape(1, -1) return out diff --git a/test/test_sample.py b/test/test_sample.py index 0ffb63c8..7f4f531f 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -2138,3 +2138,13 @@ def test_sample_dense_mass_matrix(): fit = linear_model.sample(data=jdata, metric="dense_e", chains=2) assert fit.metric is not None assert fit.metric.shape == (2, 3, 3) + + +def test_no_output_draws(): + stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan') + model = cmdstanpy.CmdStanModel(stan_file=stan) + data = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') + + mcmc = model.sample(data=data, iter_sampling=0, save_warmup=False, chains=2) + draws = mcmc.draws() + assert np.array_equal(draws, np.empty((0, 2, len(mcmc.column_names)))) diff --git a/test/test_stancsv.py b/test/test_stancsv.py index 7f3910c4..dc6dd0ee 100644 --- a/test/test_stancsv.py +++ b/test/test_stancsv.py @@ -136,17 +136,10 @@ def test_csv_bytes_to_numpy_with_header_no_polars(): assert np.array_equal(arr_out, expected) -def test_csv_bytes_to_numpy_empty(): - lines = [b""] - with pytest.raises(ValueError): - stancsv.csv_bytes_list_to_numpy(lines) - - -def test_csv_bytes_to_numpy_empty_no_polars(): - lines = [b""] - with without_import("polars", cmdstanpy.utils.stancsv): - with pytest.raises(ValueError): - stancsv.csv_bytes_list_to_numpy(lines) +def test_csv_bytes_empty(): + lines = [] + arr = stancsv.csv_bytes_list_to_numpy(lines) + assert np.array_equal(arr, np.empty((0,))) def test_csv_bytes_to_numpy_header_no_draws(): @@ -156,8 +149,8 @@ def test_csv_bytes_to_numpy_header_no_draws(): b"n_leapfrog__,divergent__,energy__,theta\n" ), ] - with pytest.raises(ValueError): - stancsv.csv_bytes_list_to_numpy(lines) + arr = stancsv.csv_bytes_list_to_numpy(lines) + assert arr.shape == (0, 8) def test_csv_bytes_to_numpy_header_no_draws_no_polars(): @@ -168,8 +161,8 @@ def test_csv_bytes_to_numpy_header_no_draws_no_polars(): ), ] with without_import("polars", cmdstanpy.utils.stancsv): - with pytest.raises(ValueError): - stancsv.csv_bytes_list_to_numpy(lines) + arr = stancsv.csv_bytes_list_to_numpy(lines) + assert arr.shape == (0, 8) def test_parse_comments_and_draws():