@property
def time(self) -> List[Dict[str, float]]:
    """
    Per-chain timing information collected while the CSV files were read.

    The list holds one dict per chain, each with the keys "warmup",
    "sampling" and "total". The stored list itself is returned, not a copy.
    """
    chain_times = self._chain_time
    return chain_times
def scan_time(fd: TextIO, config_dict: Dict[str, Any], lineno: int) -> int:
    """
    Read the elapsed-time comment block that ends a Stan CSV file.

    The block is expected to look like this::

        # Elapsed Time: 0.001332 seconds (Warm-up)
        #               0.000249 seconds (Sampling)
        #               0.001581 seconds (Total)

    Comment lines that are empty after removing ``#`` are skipped (but still
    counted). Reading stops at end of file or at the first line that is not
    a comment; such a line is pushed back onto the stream and not counted.

    The three values are stored under ``config_dict['time']`` as a dict
    with keys 'warmup', 'sampling' and 'total'.

    :param fd: Open text stream positioned at the comment rows that follow
        the sample data.
    :param config_dict: Dictionary that receives the 'time' entry.
    :param lineno: Line number of the last line read before this call.
    :return: Line number of the last comment line consumed.
    :raises ValueError: if a non-empty comment line is not a recognizable
        timing line, if its value is not a number, or if any of the three
        timings is missing when scanning stops.
    """
    # (marker text, result key, index of the numeric token); checked in this
    # order, so 'Warm-up' takes precedence when several markers appear.
    # The warm-up row carries the extra "Elapsed Time:" prefix, hence index 2.
    layouts = (
        ('Warm-up', 'warmup', 2),
        ('Sampling', 'sampling', 0),
        ('Total', 'total', 0),
    )
    timings: Dict[str, float] = {}

    while True:
        bookmark = fd.tell()
        raw = fd.readline()
        if not raw:
            break
        text = raw.strip()
        if not text.startswith('#'):
            # not part of the comment block: un-read it for the caller
            fd.seek(bookmark)
            break
        lineno += 1

        content = text.lstrip('#').strip()
        if not content:
            continue

        tokens = content.split()
        layout = None
        if len(tokens) >= 3:
            layout = next(
                (entry for entry in layouts if entry[0] in content), None
            )
        if layout is None:
            raise ValueError(f"Invalid time at line {lineno}: {content}")

        _, slot, position = layout
        try:
            timings[slot] = float(tokens[position])
        except ValueError as e:
            raise ValueError(
                f"Invalid time at line {lineno}: {content}"
            ) from e

    missing = [slot for _, slot, _ in layouts if slot not in timings]
    if missing:
        raise ValueError(f"Invalid time, stopped at {lineno}")

    config_dict['time'] = timings
    return lineno
def test_scan_time_normal() -> None:
    """A complete three-line timing block is parsed and all lines counted."""
    timing_block = "".join(
        [
            "# Elapsed Time: 0.005 seconds (Warm-up)\n",
            "# 0 seconds (Sampling)\n",
            "# 0.005 seconds (Total)\n",
        ]
    )
    parsed = {}
    lines_read = stancsv.scan_time(io.StringIO(timing_block), parsed, 0)
    assert lines_read == 3
    assert parsed.get('time') == {
        'warmup': 0.005,
        'sampling': 0.0,
        'total': 0.005,
    }


def test_scan_time_no_timing() -> None:
    """Comment lines that carry no timing information are rejected."""
    comments = "".join(
        [
            "# merrily we roll along\n",
            "# roll along\n",
            "# very merrily we roll along\n",
        ]
    )
    stream = io.StringIO(comments)
    parsed = {}
    with pytest.raises(ValueError, match="Invalid time"):
        stancsv.scan_time(stream, parsed, 0)


def test_scan_time_invalid_value() -> None:
    """A non-numeric time value is rejected."""
    timing_block = "".join(
        [
            "# Elapsed Time: abc seconds (Warm-up)\n",
            "# 0.200 seconds (Sampling)\n",
            "# 0.300 seconds (Total)\n",
        ]
    )
    stream = io.StringIO(timing_block)
    parsed = {}
    with pytest.raises(ValueError, match="Invalid time"):
        stancsv.scan_time(stream, parsed, 0)


def test_scan_time_invalid_string() -> None:
    """A timing line with an unknown phase label is rejected."""
    timing_block = "".join(
        [
            "# Elapsed Time: 0.22 seconds (foo)\n",
            "# 0.200 seconds (Sampling)\n",
            "# 0.300 seconds (Total)\n",
        ]
    )
    stream = io.StringIO(timing_block)
    parsed = {}
    with pytest.raises(ValueError, match="Invalid time"):
        stancsv.scan_time(stream, parsed, 0)