1616from cmdstanpy import _CMDSTAN_SAMPLING , _CMDSTAN_THIN , _CMDSTAN_WARMUP
1717
1818
def parse_comments_header_and_draws(
    stan_csv: Union[str, os.PathLike, Iterator[bytes]],
) -> Tuple[List[bytes], Optional[str], List[bytes]]:
    """Parses lines of a Stan CSV file into comment lines, the header line,
    and draws lines.

    A comment line is any line starting with ``#``. The header is assumed to
    be the first non-comment line; it is returned stripped of surrounding
    whitespace and decoded to ``str``. Every later non-comment line is a
    draws line and is returned as raw bytes, unmodified.

    ``stan_csv`` is either a path to a Stan CSV file (opened in binary mode)
    or an iterator over the lines of such a file as bytes.

    Returns a (comment_lines, header, draws_lines) tuple, where header is
    None if no non-comment line was found.
    """

    def partition_csv(
        lines: Iterator[bytes],
    ) -> Tuple[List[bytes], Optional[str], List[bytes]]:
        comment_lines: List[bytes] = []
        draws_lines: List[bytes] = []
        header: Optional[str] = None
        for line in lines:
            if line.startswith(b"#"):  # is comment line
                comment_lines.append(line)
            elif header is None:  # Assumes the header is the first non-comment
                header = line.strip().decode()
            else:
                draws_lines.append(line)
        return comment_lines, header, draws_lines

    if isinstance(stan_csv, (str, os.PathLike)):
        # Binary mode: lines are kept as bytes for the downstream parsers
        with open(stan_csv, "rb") as f:
            return partition_csv(f)
    return partition_csv(stan_csv)
4448
4549
4650def filter_csv_bytes_by_columns (
@@ -58,13 +62,15 @@ def filter_csv_bytes_by_columns(
5862
5963
6064def csv_bytes_list_to_numpy (
61- csv_bytes_list : List [bytes ], includes_header : bool = True
65+ csv_bytes_list : List [bytes ],
6266) -> npt .NDArray [np .float64 ]:
6367 """Efficiently converts a list of bytes representing whose concatenation
64- represents a CSV file into a numpy array. Includes header specifies
65- whether the bytes contains an initial header line."""
68+ represents a CSV file into a numpy array.
69+
70+ Returns a 2D numpy array with shape (n_rows, n_cols). If no data is found,
71+ returns an empty array with shape (0, 0)."""
6672 if not csv_bytes_list :
67- return np .empty ((0 ,))
73+ return np .empty ((0 , 0 ))
6874 num_cols = csv_bytes_list [0 ].count (b"," ) + 1
6975 try :
7076 import polars as pl
@@ -73,30 +79,26 @@ def csv_bytes_list_to_numpy(
7379 out : npt .NDArray [np .float64 ] = (
7480 pl .read_csv (
7581 io .BytesIO (b"" .join (csv_bytes_list )),
76- has_header = includes_header ,
82+ has_header = False ,
7783 schema_overrides = [pl .Float64 ] * num_cols ,
7884 infer_schema = False ,
7985 )
8086 .to_numpy ()
8187 .astype (np .float64 )
8288 )
8389 except pl .exceptions .NoDataError :
84- return np .empty ((0 ,))
90+ return np .empty ((0 , 0 ))
8591 except ImportError :
8692 with warnings .catch_warnings ():
8793 warnings .filterwarnings ("ignore" )
8894 out = np .loadtxt (
8995 csv_bytes_list ,
90- skiprows = int (includes_header ),
9196 delimiter = "," ,
9297 dtype = np .float64 ,
93- ndmin = 1 ,
98+ ndmin = 2 ,
9499 )
95- if len (out .shape ) == 1 :
96- if out .shape [0 ] == 0 : # No data read
97- out = np .empty ((0 , num_cols ))
98- else :
99- out = out .reshape (1 , - 1 )
100+ if out .shape [0 ] == 0 : # No data read
101+ out = np .empty ((0 , 0 ))
100102
101103 return out
102104
@@ -133,9 +135,7 @@ def parse_hmc_adaptation_lines(
133135 elif b"diag_e" in line :
134136 diag_e_metric = True
135137 if matrix_lines :
136- mass_matrix = csv_bytes_list_to_numpy (
137- matrix_lines , includes_header = False
138- )
138+ mass_matrix = csv_bytes_list_to_numpy (matrix_lines )
139139 if diag_e_metric and mass_matrix .shape [0 ] == 1 :
140140 mass_matrix = mass_matrix [0 ]
141141 return step_size , mass_matrix
@@ -188,46 +188,30 @@ def parse_config(
188188 return out
189189
190190
191- def extract_header_line (draws_lines : List [bytes ]) -> str :
192- """Attempts to extract the header line from the draw lines list.
193-
194- Returns the raw header line as a string"""
195- if not draws_lines :
196- raise ValueError ("Attempting to parse header from empty list" )
197-
198- first_line = draws_lines [0 ]
199- if not first_line :
200- raise ValueError ("Empty first line when attempting to parse header" )
201- first_char = first_line [0 ]
202-
203- if first_char in b"1234567890-" :
204- raise ValueError ("Header line appears to be numeric data" )
205-
206- return first_line .decode ().strip ()
207-
208-
def parse_header(header: str) -> Tuple[str, ...]:
    """Splits a Stan csv header line on commas and returns the munged
    variable name of each column, in order, as a tuple"""
    raw_names = header.split(",")
    return tuple(map(munge_varname, raw_names))
212194
213195
def construct_config_header_dict(
    comment_lines: List[bytes], header: Optional[str]
) -> Dict[str, Union[str, int, float, Tuple[str, ...]]]:
    """Builds a dictionary of config and header info for a Stan CSV file.

    The config entries are parsed from the comment lines. If a non-empty
    header is given, the dictionary additionally holds the header string
    under "raw_header" and the parsed column names under "column_names".
    When header is None or empty, neither key is added.
    """
    config = parse_config(comment_lines)
    out: Dict[str, Union[str, int, float, Tuple[str, ...]]] = {**config}
    if header:
        out["raw_header"] = header
        out["column_names"] = parse_header(header)
    return out
225207
226208
227209def parse_variational_eta (comment_lines : List [bytes ]) -> float :
228210 """Extracts the variational eta parameter from stancsv comment lines"""
229211 for i , line in enumerate (comment_lines ):
230- if line .startswith (b"# Stepsize adaptation" ):
212+ if line .startswith (b"# Stepsize adaptation" ) and (
213+ i + 1 < len (comment_lines ) # Ensure i + 1 is in bounds
214+ ):
231215 eta_line = comment_lines [i + 1 ]
232216 break
233217 else :
@@ -240,18 +224,18 @@ def parse_variational_eta(comment_lines: List[bytes]) -> float:
240224
241225
def extract_max_treedepth_and_divergence_counts(
    header: str,
    draws_lines: List[bytes],
    max_treedepth: int,
    warmup_draws: int,
) -> Tuple[int, int]:
    """Extracts the max treedepth and divergence counts from the draw lines
    of the MCMC stan csv output.

    header is the comma-separated column header line; draws_lines holds
    only draws (the header is passed separately and is not its first
    element). The first warmup_draws entries of draws_lines are skipped.

    Returns a (num_max_treedepth, num_divergences) tuple: the number of
    sampling draws whose treedepth__ equals max_treedepth, and the sum of
    the divergent__ column. Returns (0, 0) if there are no sampling draws
    or if either column is missing from the header.
    """
    if not draws_lines:  # Empty draws
        return 0, 0
    column_names = header.split(",")

    try:
        indexes_to_keep = [
            column_names.index("treedepth__"),
            column_names.index("divergent__"),
        ]
    except ValueError:
        # Throws if treedepth/divergent columns not recorded
        return 0, 0

    # draws_lines carries no header line, so only the warmup draws are
    # skipped here.
    sampling_draws = draws_lines[warmup_draws:]
    if not sampling_draws:
        # Nothing left after warmup; avoid indexing columns of an empty array
        return 0, 0

    filtered = filter_csv_bytes_by_columns(sampling_draws, indexes_to_keep)
    arr = csv_bytes_list_to_numpy(filtered).astype(int)

    num_max_treedepth = np.sum(arr[:, 0] == max_treedepth)
    num_divergences = np.sum(arr[:, 1])
    return num_max_treedepth, num_divergences
268252
269253
def is_sneaky_fixed_param(header: str) -> bool:
    """Returns True if the header line indicates that the sampler
    ran with the fixed_param sampler automatically, despite the
    algorithm listed as 'hmc'.

    header is the comma-separated column header of a Stan CSV file. The
    check passes when an ``lp__`` column is present but fewer than 7
    columns end in a double underscore (presumably the sampler diagnostic
    columns a real HMC run records -- TODO confirm against Stan CSV docs).

    See issue #805"""
    columns = [col.strip() for col in header.split(",")]
    num_dunder_cols = sum(col.endswith("__") for col in columns)

    # Match lp__ as a whole column name; a substring test on the raw header
    # would also fire on names that merely contain "lp__".
    return (num_dunder_cols < 7) and "lp__" in columns
281263
282264
283265def count_warmup_and_sampling_draws (
@@ -300,7 +282,9 @@ def determine_draw_counts(lines: Iterator[bytes]) -> Tuple[int, int]:
300282 if line .startswith (b"lp__" ):
301283 header_line_idx = i
302284 if not is_fixed_param :
303- is_fixed_param = is_sneaky_fixed_param (line )
285+ is_fixed_param = is_sneaky_fixed_param (
286+ line .strip ().decode ()
287+ )
304288 continue
305289
306290 if not is_fixed_param and adaptation_block_idx is None :
@@ -339,7 +323,9 @@ def determine_draw_counts(lines: Iterator[bytes]) -> Tuple[int, int]:
339323 return determine_draw_counts (stan_csv )
340324
341325
342- def raise_on_inconsistent_draws_shape (draw_lines : List [bytes ]) -> None :
326+ def raise_on_inconsistent_draws_shape (
327+ header : str , draw_lines : List [bytes ]
328+ ) -> None :
343329 """Throws a ValueError if any draws are found to have an inconsistent
344330 shape, i.e. too many/few columns compared to the header"""
345331
@@ -350,9 +336,8 @@ def column_count(ln: bytes) -> int:
350336 if not draw_lines :
351337 return
352338
353- header , * draws = draw_lines
354- num_cols = column_count (header )
355- for i , draw in enumerate (draws , start = 1 ):
339+ num_cols = column_count (header .encode ())
340+ for i , draw in enumerate (draw_lines , start = 1 ):
356341 if (draw_size := column_count (draw )) != num_cols :
357342 raise ValueError (
358343 f"line { i } : bad draw, expecting { num_cols } items,"
@@ -488,18 +473,22 @@ def parse_sampler_metadata_from_csv(
488473) -> Dict [str , Union [int , float , str , Tuple [str , ...], Dict [str , float ]]]:
489474 """Parses sampling metadata from a given Stan CSV path for a sample run"""
490475 try :
491- comments , draws = parse_stan_csv_comments_and_draws (path )
492- raise_on_inconsistent_draws_shape (draws )
493- config = extract_config_and_header_info (comments , draws )
476+ comments , header , draws = parse_comments_header_and_draws (path )
477+ if header is None :
478+ raise ValueError ("No header line found in stan csv" )
479+ raise_on_inconsistent_draws_shape (header , draws )
480+ config = construct_config_header_dict (comments , header )
494481 num_warmup , num_sampling = count_warmup_and_sampling_draws (path )
495482 timings = parse_timing_lines (comments )
496- if (config ['algorithm' ] != 'fixed_param' ) and not is_sneaky_fixed_param (
497- draws [0 ]
483+ if (
484+ (config ['algorithm' ] != 'fixed_param' )
485+ and header
486+ and not is_sneaky_fixed_param (header )
498487 ):
499488 raise_on_invalid_adaptation_block (comments )
500489 max_depth : int = config ["max_depth" ] # type: ignore
501490 max_tree_hits , divs = extract_max_treedepth_and_divergence_counts (
502- draws , max_depth , num_warmup
491+ header , draws , max_depth , num_warmup
503492 )
504493 else :
505494 max_tree_hits , divs = 0 , 0
0 commit comments