Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ jobs:
run: python -m pip freeze

- name: Run flake8, pylint, mypy
if: matrix.python-version == '3.11'
if: matrix.python-version == '3.14'
run: |
flake8 cmdstanpy test
pylint -v cmdstanpy test
mypy cmdstanpy
mypy cmdstanpy test

- name: CmdStan installation cacheing
id: cache-cmdstan
Expand Down
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ repos:
rev: v1.5.0
hooks:
- id: mypy
exclude: ^test/
additional_dependencies: [ numpy >= 1.22]
# local uses the user-installed pylint, this allows dependency checking
- repo: local
Expand Down
2 changes: 1 addition & 1 deletion cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ def __init__(
PathfinderArgs,
],
data: Union[Mapping[str, Any], str, None] = None,
seed: Union[int, list[int], None] = None,
seed: Union[int, np.integer, list[int], list[np.integer], None] = None,
inits: Union[int, float, str, list[str], None] = None,
output_dir: OptionalPath = None,
sig_figs: Optional[int] = None,
Expand Down
15 changes: 2 additions & 13 deletions cmdstanpy/install_cmdstan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
import urllib.error
import urllib.request
from collections import OrderedDict
from functools import cached_property
from pathlib import Path
from time import sleep
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union

from tqdm.auto import tqdm

Expand All @@ -47,18 +48,6 @@

from . import progress as progbar

if sys.version_info >= (3, 8) or TYPE_CHECKING:
# mypy only knows about the new built-in cached_property
from functools import cached_property
else:
# on older Python versions, this is the recommended
# way to get the same effect
from functools import lru_cache

def cached_property(fun):
return property(lru_cache(maxsize=None)(fun))


try:
# on MacOS and Linux, importing this
# improves the UX of the input() function
Expand Down
15 changes: 10 additions & 5 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from concurrent.futures import ThreadPoolExecutor
from io import StringIO
from multiprocessing import cpu_count
from typing import Any, Callable, Mapping, Optional, TypeVar, Union
from typing import Any, Callable, Mapping, Optional, Sequence, TypeVar, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -461,8 +461,7 @@ def sample(
Mapping[str, Any],
float,
str,
list[str],
list[Mapping[str, Any]],
Sequence[Union[str, Mapping[str, Any]]],
None,
] = None,
iter_warmup: Optional[int] = None,
Expand Down Expand Up @@ -493,7 +492,7 @@ def sample(
str,
np.ndarray,
Mapping[str, Any],
list[Union[str, np.ndarray, Mapping[str, Any]]],
Sequence[Union[str, np.ndarray, Mapping[str, Any]]],
None,
] = None,
) -> CmdStanMCMC:
Expand Down Expand Up @@ -1360,7 +1359,13 @@ def pathfinder(
calculate_lp: bool = True,
# arguments standard to all methods
seed: Optional[int] = None,
inits: Union[dict[str, float], float, str, os.PathLike, None] = None,
inits: Union[
Mapping[str, Any],
float,
str,
Sequence[Union[str, Mapping[str, Any]]],
None,
] = None,
output_dir: OptionalPath = None,
sig_figs: Optional[int] = None,
save_profile: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion cmdstanpy/stanfit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def from_csv(
mode: CmdStanMLE = from_csv(
config_dict['mode'], # type: ignore
method='optimize',
) # type: ignore
)
return CmdStanLaplace(runset, mode=mode)
elif config_dict['method'] == 'pathfinder':
pathfinder_args = PathfinderArgs(
Expand Down
4 changes: 2 additions & 2 deletions cmdstanpy/stanfit/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def _validate_csv_files(self) -> dict[str, Any]:
save_warmup=self._save_warmup,
thin=self._thin,
)
self._chain_time.append(dzero['time']) # type: ignore
self._chain_time.append(dzero['time'])
if not self._is_fixed_param:
self._divergences[i] = dzero['ct_divergences']
self._max_treedepths[i] = dzero['ct_max_treedepth']
Expand All @@ -360,7 +360,7 @@ def _validate_csv_files(self) -> dict[str, Any]:
save_warmup=self._save_warmup,
thin=self._thin,
)
self._chain_time.append(drest['time']) # type: ignore
self._chain_time.append(drest['time'])
for key in dzero:
# check args that matter for parsing, plus name, version
if (
Expand Down
2 changes: 1 addition & 1 deletion cmdstanpy/stanfit/pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def is_resampled(self) -> bool:
Returns True if the draws were resampled from several Pathfinder
approximations, False otherwise.
"""
return ( # type: ignore
return (
self._metadata.cmdstan_config.get("num_paths", 4) > 1
and self._metadata.cmdstan_config.get('psis_resample', 1)
in (1, 'true')
Expand Down
2 changes: 1 addition & 1 deletion cmdstanpy/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def show_versions(output: bool = True) -> str:
deps_info.append((module, None))
else:
try:
ver = mod.__version__ # type: ignore
ver = mod.__version__
deps_info.append((module, ver))
# pylint: disable=broad-except
except Exception:
Expand Down
16 changes: 10 additions & 6 deletions cmdstanpy/utils/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import re
import shutil
import tempfile
from typing import Any, Iterator, Mapping, Optional, Union
from typing import Any, Iterator, Mapping, Optional, Sequence, Union

import numpy as np

Expand Down Expand Up @@ -131,10 +131,12 @@ def _temp_single_json(


def _temp_multiinput(
input: Union[str, os.PathLike, Mapping[str, Any], list[Any], None],
input: Union[str, os.PathLike, Mapping[str, Any], Sequence[Any], None],
base: int = 1,
) -> Iterator[Optional[str]]:
if isinstance(input, list):
if isinstance(input, Sequence) and not isinstance(
input, (str, os.PathLike)
):
# most complicated case: list of inits
# for multiple chains, we need to create multiple files
# which look like somename_{i}.json and then pass somename.json
Expand Down Expand Up @@ -170,7 +172,7 @@ def _temp_multiinput(
@contextlib.contextmanager
def temp_metrics(
metrics: Union[
str, os.PathLike, Mapping[str, Any], np.ndarray, list[Any], None
str, os.PathLike, Mapping[str, Any], np.ndarray, Sequence[Any], None
],
*,
id: int = 1,
Expand Down Expand Up @@ -200,7 +202,7 @@ def temp_metrics(
@contextlib.contextmanager
def temp_inits(
inits: Union[
str, os.PathLike, Mapping[str, Any], float, int, list[Any], None
str, os.PathLike, Mapping[str, Any], float, int, Sequence[Any], None
],
*,
allow_multiple: bool = True,
Expand All @@ -212,7 +214,9 @@ def temp_inits(
if allow_multiple:
yield from _temp_multiinput(inits, base=id)
else:
if isinstance(inits, list):
if isinstance(inits, Sequence) and not isinstance(
inits, (str, os.PathLike)
):
raise ValueError('Expected single initialization, got list')
yield from _temp_single_json(inits)

Expand Down
8 changes: 5 additions & 3 deletions cmdstanpy/utils/stancsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import re
import warnings
from typing import Any, Iterator, Mapping, Optional, Union
from typing import Any, Iterator, Mapping, Optional, Sequence, Union

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -638,11 +638,13 @@ def try_deduce_metric_type(
str,
np.ndarray,
Mapping[str, Any],
list[Union[str, np.ndarray, Mapping[str, Any]]],
Sequence[Union[str, np.ndarray, Mapping[str, Any]]],
],
) -> Optional[str]:
"""Given a user-supplied metric, try to infer the correct metric type."""
if isinstance(inv_metric, list):
if isinstance(inv_metric, Sequence) and not isinstance(
inv_metric, (str, np.ndarray, Mapping)
):
if inv_metric:
inv_metric = inv_metric[0]

Expand Down
25 changes: 17 additions & 8 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import platform
import re
from importlib import reload
from typing import Tuple, Type
from types import ModuleType
from typing import Generator, Optional, Type
from unittest import mock

import pytest
Expand All @@ -20,15 +21,17 @@

# pylint: disable=invalid-name
@contextlib.contextmanager
def raises_nested(expected_exception: Type[Exception], match: str) -> None:
def raises_nested(
expected_exception: Type[Exception], match: str
) -> Generator[None, None, None]:
"""A version of assertRaisesRegex that checks the full traceback.

Useful for when an exception is raised from another and you wish to
inspect the inner exception.
"""
with pytest.raises(expected_exception) as ctx:
yield
exception: Exception = ctx.value
exception: Optional[BaseException] = ctx.value
lines = []
while exception:
lines.append(str(exception))
Expand All @@ -38,7 +41,9 @@ def raises_nested(expected_exception: Type[Exception], match: str) -> None:


@contextlib.contextmanager
def without_import(library, module):
def without_import(
library: str, module: ModuleType
) -> Generator[None, None, None]:
with mock.patch.dict('sys.modules', {library: None}):
reload(module)
yield
Expand All @@ -47,7 +52,7 @@ def without_import(library, module):

def check_present(
caplog: pytest.LogCaptureFixture,
*conditions: Tuple,
*conditions: tuple,
clear: bool = True,
) -> None:
"""
Expand All @@ -58,9 +63,13 @@ def check_present(
if isinstance(level, str):
level = getattr(logging, level)
found = any(
logger == logger_ and level == level_ and message.match(message_)
if isinstance(message, re.Pattern)
else message == message_
(
logger == logger_
and level == level_
and message.match(message_)
if isinstance(message, re.Pattern)
else message == message_
)
for logger_, level_, message_ in caplog.record_tuples
)
if not found:
Expand Down
3 changes: 2 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""The global configuration for the test suite"""
import os
import subprocess
from typing import Generator

import pytest

Expand All @@ -9,7 +10,7 @@


@pytest.fixture(scope='session', autouse=True)
def cleanup_test_files():
def cleanup_test_files() -> Generator[None, None, None]:
"""Remove compiled models and output files after test run."""
yield
subprocess.Popen(
Expand Down
33 changes: 0 additions & 33 deletions test/example_script.py

This file was deleted.

Loading