|
1 | 1 | """ |
2 | 2 | Utility functions for reading the Stan CSV format |
3 | 3 | """ |
| 4 | + |
| 5 | +import io |
4 | 6 | import json |
5 | 7 | import math |
| 8 | +import os |
6 | 9 | import re |
7 | | -from typing import Any, Dict, List, MutableMapping, Optional, TextIO, Union |
| 10 | +import warnings |
| 11 | +from typing import ( |
| 12 | + Any, |
| 13 | + Dict, |
| 14 | + Iterator, |
| 15 | + List, |
| 16 | + MutableMapping, |
| 17 | + Optional, |
| 18 | + TextIO, |
| 19 | + Tuple, |
| 20 | + Union, |
| 21 | +) |
8 | 22 |
|
9 | 23 | import numpy as np |
| 24 | +import numpy.typing as npt |
10 | 25 | import pandas as pd |
11 | 26 |
|
12 | 27 | from cmdstanpy import _CMDSTAN_SAMPLING, _CMDSTAN_THIN, _CMDSTAN_WARMUP |
13 | 28 |
|
14 | 29 |
|
def parse_stan_csv_comments_and_draws(
    stan_csv: Union[str, os.PathLike, Iterator[bytes]],
) -> Tuple[List[bytes], List[bytes]]:
    """Split the lines of a Stan CSV source into comment and draws lines.

    ``stan_csv`` is either a path, which is opened in binary mode, or an
    iterator yielding the lines as bytes. A line is a comment line when it
    starts with ``#``; every other line is treated as a draws line.

    Returns a (comment_lines, draws_lines) tuple.
    """

    def partition(rows: Iterator[bytes]) -> Tuple[List[bytes], List[bytes]]:
        comments: List[bytes] = []
        draws: List[bytes] = []
        # Single pass, so one-shot iterators (such as an open file handle)
        # are consumed exactly once and line order is kept in each bucket.
        for row in rows:
            bucket = comments if row.startswith(b"#") else draws
            bucket.append(row)
        return comments, draws

    if not isinstance(stan_csv, (str, os.PathLike)):
        return partition(stan_csv)
    with open(stan_csv, "rb") as handle:
        return partition(handle)
| 55 | + |
| 56 | + |
def csv_bytes_list_to_numpy(
    csv_bytes_list: List[bytes], includes_header: bool = True
) -> npt.NDArray[np.float64]:
    """Efficiently converts a list of bytes objects whose concatenation
    represents a CSV file into a 2D numpy array of float64 values.

    ``includes_header`` specifies whether the bytes contain an initial
    header line, which is skipped rather than parsed as data.

    polars is used for parsing when it can be imported; otherwise the
    function falls back to ``np.loadtxt``. Either way the result is 2D, with
    one row per data line and one column per comma-separated field.

    Raises ValueError if there are no data rows to parse.
    """
    try:
        import polars as pl

        try:
            if not csv_bytes_list:
                raise ValueError("No data found to parse")
            # The column count is taken from the first line so that every
            # column can be forced to Float64.
            num_cols = csv_bytes_list[0].count(b",") + 1
            out: npt.NDArray[np.float64] = (
                pl.read_csv(
                    io.BytesIO(b"".join(csv_bytes_list)),
                    has_header=includes_header,
                    schema_overrides=[pl.Float64] * num_cols,
                    infer_schema=False,
                )
                .to_numpy()
                .astype(np.float64)
            )
            if out.shape[0] == 0:
                raise ValueError("No data found to parse")
        except pl.exceptions.NoDataError as exc:
            raise ValueError("No data found to parse") from exc
    except ImportError:
        with warnings.catch_warnings():
            # Silence loadtxt's warnings (e.g. for empty input); emptiness is
            # reported through the ValueError below instead.
            warnings.filterwarnings("ignore")
            # ndmin=2 keeps the (rows, columns) layout even for a single row
            # or a single column. Reading with ndmin=1 and reshaping to
            # (1, -1) would turn a one-column, many-row input into one row.
            out = np.loadtxt(
                csv_bytes_list,
                skiprows=int(includes_header),
                delimiter=",",
                dtype=np.float64,
                ndmin=2,
            )
        if out.size == 0:
            raise ValueError("No data found to parse")  # pylint: disable=W0707

    return out
| 100 | + |
| 101 | + |
def parse_hmc_adaptation_lines(
    comment_lines: List[bytes],
) -> Tuple[float, Optional[npt.NDArray[np.float64]]]:
    """Pull the adapted step size and mass matrix out of Stan CSV comments.

    Each comment line has its leading ``#`` and space characters stripped
    and is then scanned for the adaptation section. With the diag_e metric
    the mass matrix is returned as a 1D array of the diagonal elements, with
    dense_e as a 2D array holding the whole matrix, and with unit_e (no
    matrix lines present) as None.

    Raises ValueError if no step size line is found.

    Returns a (step_size, mass_matrix) tuple"""
    adapted_step: Optional[float] = None
    saw_diag_e = False
    collecting = False
    metric_rows: List[bytes] = []

    for raw in comment_lines:
        text = raw.lstrip(b"# ")
        if collecting and text.strip():
            # Inside the matrix block: every non-blank line is a matrix row
            # until the timing block begins.
            if text.startswith(b"Elapsed Time"):
                break
            metric_rows.append(text)
            continue
        if text.startswith(b"Step size"):
            _, value = text.split(b" = ")
            adapted_step = float(value)
        elif text.startswith((b"Diagonal", b"Elements")):
            # Header line announcing the matrix rows that follow.
            collecting = True
        elif text.startswith(b"No free"):
            break
        elif b"diag_e" in text:
            saw_diag_e = True

    if adapted_step is None:
        raise ValueError("Unable to parse adapated step size")
    if not metric_rows:
        return adapted_step, None

    metric = csv_bytes_list_to_numpy(metric_rows, includes_header=False)
    # A diag_e metric parsed from a single line is flattened to 1D.
    if saw_diag_e and metric.shape[0] == 1:
        metric = metric[0]
    return adapted_step, metric
| 142 | + |
| 143 | + |
15 | 144 | def check_sampler_csv( |
16 | 145 | path: str, |
17 | 146 | is_fixed_param: bool = False, |
|
0 commit comments