Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
from numpy.random import default_rng

from cmdstanpy.utils import cmdstan_path, cmdstan_version_before, get_logger
from cmdstanpy.utils import get_logger

OptionalPath = str | os.PathLike | None

Expand Down Expand Up @@ -748,15 +748,6 @@ def validate(self) -> None:
'Argument "sig_figs" must be an integer between 1 and 18,'
' found {}'.format(self.sig_figs)
)
# TODO: remove at some future release
if cmdstan_version_before(2, 25):
self.sig_figs = None
get_logger().warning(
'Argument "sig_figs" invalid for CmdStan versions < 2.25, '
'using version %s in directory %s',
os.path.basename(cmdstan_path()),
os.path.dirname(cmdstan_path()),
)

if self.seed is None:
rng = default_rng()
Expand Down
45 changes: 8 additions & 37 deletions cmdstanpy/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,7 @@
from typing import Any, Iterable

from cmdstanpy.utils import get_logger
from cmdstanpy.utils.cmdstan import (
EXTENSION,
cmdstan_path,
cmdstan_version,
cmdstan_version_before,
stanc_path,
)
from cmdstanpy.utils.cmdstan import EXTENSION, cmdstan_path, stanc_path
from cmdstanpy.utils.command import do_command
from cmdstanpy.utils.filesystem import SanitizedOrTmpFilePath

Expand Down Expand Up @@ -463,38 +457,15 @@ def format_stan_file(
)

if canonicalize:
if cmdstan_version_before(2, 29):
if isinstance(canonicalize, bool):
cmd.append('--print-canonical')
else:
raise ValueError(
"Invalid arguments passed for current CmdStan"
+ " version({})\n".format(
cmdstan_version() or "Unknown"
)
+ "--canonicalize requires 2.29 or higher"
)
if isinstance(canonicalize, str):
cmd.append('--canonicalize=' + canonicalize)
elif isinstance(canonicalize, Iterable):
cmd.append('--canonicalize=' + ','.join(canonicalize))
else:
if isinstance(canonicalize, str):
cmd.append('--canonicalize=' + canonicalize)
elif isinstance(canonicalize, Iterable):
cmd.append('--canonicalize=' + ','.join(canonicalize))
else:
cmd.append('--print-canonical')

# before 2.29, having both --print-canonical
# and --auto-format printed twice
if not (cmdstan_version_before(2, 29) and canonicalize):
cmd.append('--auto-format')
cmd.append('--print-canonical')

if not cmdstan_version_before(2, 29):
cmd.append(f'--max-line-length={max_line_length}')
elif max_line_length != 78:
raise ValueError(
"Invalid arguments passed for current CmdStan version"
+ " ({})\n".format(cmdstan_version() or "Unknown")
+ "--max-line-length requires 2.29 or higher"
)
cmd.append('--auto-format')
cmd.append(f'--max-line-length={max_line_length}')

out = subprocess.run(cmd, capture_output=True, text=True, check=True)
if out.stderr:
Expand Down
2 changes: 1 addition & 1 deletion cmdstanpy/install_cxx_toolchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def get_toolchain_name() -> str:
return ''


# TODO(2.0): drop 3.5 support
# TODO(2.0): consider something other than RTools
def get_url(version: str) -> str:
"""Return URL for toolchain."""
url = ''
Expand Down
147 changes: 48 additions & 99 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@

import io
import os
import platform
import re
import shutil
import subprocess
import sys
import tempfile
import threading
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from io import StringIO
from multiprocessing import cpu_count
from typing import Any, Callable, Mapping, Sequence, TypeVar
from typing import Any, Callable, Mapping, Sequence

import numpy as np
import pandas as pd
Expand All @@ -37,15 +35,12 @@
CmdStanMLE,
CmdStanPathfinder,
CmdStanVB,
PrevFit,
RunSet,
from_csv,
)
from cmdstanpy.utils import (
cmdstan_path,
cmdstan_version_before,
do_command,
get_logger,
)
from cmdstanpy.utils import do_command, get_logger
from cmdstanpy.utils.cmdstan import cmdstan_version_before, windows_tbb_path
from cmdstanpy.utils.filesystem import (
temp_inits,
temp_metrics,
Expand All @@ -56,7 +51,6 @@
from . import progress as progbar

OptionalPath = str | os.PathLike | None
Fit = TypeVar('Fit', CmdStanMCMC, CmdStanMLE, CmdStanVB)


class CmdStanModel:
Expand Down Expand Up @@ -118,6 +112,8 @@ def __init__(

self._fixed_param = False

windows_tbb_path()

if exe_file is not None:
self._exe_file = os.path.realpath(os.path.expanduser(exe_file))
if not os.path.exists(self._exe_file):
Expand Down Expand Up @@ -164,33 +160,26 @@ def __init__(
)

# try to detect models w/out parameters, needed for sampler
if (not cmdstan_version_before(2, 27)) and cmdstan_version_before(
2, 36
):
if cmdstan_version_before(2, 36):
model_info = self.src_info()
if 'parameters' in model_info:
self._fixed_param |= len(model_info['parameters']) == 0

if platform.system() == 'Windows':
try:
do_command(['where.exe', 'tbb.dll'], fd_out=None)
except RuntimeError:
# Add tbb to the $PATH on Windows
libtbb = os.environ.get('STAN_TBB')
if libtbb is None:
libtbb = os.path.join(
cmdstan_path(), 'stan', 'lib', 'stan_math', 'lib', 'tbb'
)
get_logger().debug("Adding TBB (%s) to PATH", libtbb)
os.environ['PATH'] = ';'.join(
list(
OrderedDict.fromkeys(
[libtbb] + os.environ.get('PATH', '').split(';')
)
)
)
else:
get_logger().debug("TBB already found in load path")
# check CmdStan version compatibility
exe_info = None
try:
exe_info = self.exe_info()
# pylint: disable=broad-except
except Exception as e:
get_logger().warning(
'Could not get exe info for model %s, error: %s',
self._name,
str(e),
)
if cmdstan_version_before(2, 35, exe_info):
raise RuntimeError(
"This version of CmdStanPy requires CmdStan 2.35 or higher."
)

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -238,7 +227,7 @@ def src_info(self) -> dict[str, Any]:
If stanc is older than 2.27 or if the stan
file cannot be found, returns an empty dictionary.
"""
if self.stan_file is None or cmdstan_version_before(2, 27):
if self.stan_file is None:
return {}
return compilation.src_info(str(self.stan_file), self._stanc_options)

Expand Down Expand Up @@ -404,12 +393,6 @@ def optimize(
jacobian=jacobian,
)

if jacobian and cmdstan_version_before(2, 32, self.exe_info()):
raise ValueError(
"Jacobian adjustment for optimization is only supported "
"in CmdStan 2.32 and above."
)

with (
temp_single_json(data) as _data,
temp_inits(inits, allow_multiple=False) as _inits,
Expand Down Expand Up @@ -734,34 +717,23 @@ def sample(
if chains == 1:
force_one_process_per_chain = True

if (
force_one_process_per_chain is None
and not cmdstan_version_before(2, 28, info_dict)
and stan_threads == 'true'
):
if force_one_process_per_chain is None and stan_threads == 'true':
one_process_per_chain = False
num_threads = parallel_chains * num_threads
parallel_procs = 1
if force_one_process_per_chain is False:
if not cmdstan_version_before(2, 28, info_dict):
one_process_per_chain = False
num_threads = parallel_chains * num_threads
parallel_procs = 1
if stan_threads == 'false':
get_logger().warning(
'Stan program not compiled for threading, '
'process will run chains sequentially. '
'For multi-chain parallelization, recompile '
'the model with argument '
'"cpp_options={\'STAN_THREADS\':\'TRUE\'}.'
)
else:
one_process_per_chain = False
num_threads = parallel_chains * num_threads
parallel_procs = 1
if stan_threads == 'false':
get_logger().warning(
'Installed version of CmdStan cannot multi-process '
'chains, will run %d processes. '
'Run "install_cmdstan" to upgrade to latest version.',
chains,
'Stan program not compiled for threading, '
'process will run chains sequentially. '
'For multi-chain parallelization, recompile '
'the model with argument '
'"cpp_options={\'STAN_THREADS\':\'TRUE\'}.'
)

os.environ['STAN_NUM_THREADS'] = str(num_threads)

if chain_ids is None:
Expand Down Expand Up @@ -958,15 +930,15 @@ def sample(
def generate_quantities(
self,
data: Mapping[str, Any] | str | os.PathLike | None = None,
previous_fit: Fit | list[str] | None = None,
previous_fit: PrevFit | list[str] | None = None,
seed: int | None = None,
gq_output_dir: OptionalPath = None,
sig_figs: int | None = None,
show_console: bool = False,
refresh: int | None = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: float | None = None,
) -> CmdStanGQ[Fit]:
) -> CmdStanGQ[PrevFit]:
"""
Run CmdStan's generate_quantities method which runs the generated
quantities block of a model given an existing sample.
Expand Down Expand Up @@ -1032,7 +1004,16 @@ def generate_quantities(
:return: CmdStanGQ object
"""

if isinstance(previous_fit, (CmdStanMCMC, CmdStanMLE, CmdStanVB)):
if isinstance(
previous_fit,
(
CmdStanMCMC,
CmdStanMLE,
CmdStanVB,
CmdStanLaplace,
CmdStanPathfinder,
),
):
fit_object = previous_fit
fit_csv_files = previous_fit.runset.csv_files
elif isinstance(previous_fit, list):
Expand All @@ -1042,7 +1023,7 @@ def generate_quantities(
)
try:
fit_csv_files = previous_fit
fit_object: Fit = from_csv(fit_csv_files) # type: ignore
fit_object: PrevFit = from_csv(fit_csv_files) # type: ignore
except ValueError as e:
raise ValueError(
'Invalid sample from Stan CSV files, error:\n\t{}\n\t'
Expand All @@ -1064,11 +1045,6 @@ def generate_quantities(
'to generate additional quantities of interest.'
)
elif isinstance(fit_object, CmdStanMLE):
if cmdstan_version_before(2, 31):
raise RuntimeError(
"Method generate_quantities was not "
"available for non-HMC until CmdStan 2.31"
)
chains = 1
chain_ids = [1]
if fit_object._save_iterations:
Expand All @@ -1077,11 +1053,6 @@ def generate_quantities(
'to generate additional quantities of interest.'
)
else: # isinstance(fit_object, CmdStanVB)
if cmdstan_version_before(2, 31):
raise RuntimeError(
"Method generate_quantities was not "
"available for non-HMC until CmdStan 2.31"
)
chains = 1
chain_ids = [1]

Expand Down Expand Up @@ -1492,19 +1463,6 @@ def pathfinder(
"""

exe_info = self.exe_info()
if cmdstan_version_before(2, 33, exe_info):
raise ValueError(
"Method 'pathfinder' not available for CmdStan versions "
"before 2.33"
)

if (not psis_resample or not calculate_lp) and cmdstan_version_before(
2, 34, exe_info
):
raise ValueError(
"Arguments 'psis_resample' and 'calculate_lp' are only "
"available for CmdStan versions 2.34 and later"
)

if num_threads is not None:
if (
Expand Down Expand Up @@ -1613,11 +1571,6 @@ def log_prob(
unconstrained parameters of the model.
"""

if cmdstan_version_before(2, 31, self.exe_info()):
raise ValueError(
"Method 'log_prob' not available for CmdStan versions "
"before 2.31"
)
with (
temp_single_json(data) as _data,
temp_single_json(params) as _params,
Expand Down Expand Up @@ -1729,11 +1682,7 @@ def laplace_sample(

:return: A :class:`CmdStanLaplace` object.
"""
if cmdstan_version_before(2, 32, self.exe_info()):
raise ValueError(
"Method 'laplace_sample' not available for CmdStan versions "
"before 2.32"
)

if opt_args is not None and mode is not None:
raise ValueError(
"Cannot specify both 'opt_args' and 'mode' arguments"
Expand Down
3 changes: 2 additions & 1 deletion cmdstanpy/stanfit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from cmdstanpy.utils import check_sampler_csv, get_logger, stancsv

from .gq import CmdStanGQ
from .gq import CmdStanGQ, PrevFit
from .laplace import CmdStanLaplace
from .mcmc import CmdStanMCMC
from .metadata import InferenceMetadata
Expand All @@ -31,6 +31,7 @@
"CmdStanGQ",
"CmdStanLaplace",
"CmdStanPathfinder",
"PrevFit",
]


Expand Down
Loading