Skip to content

Commit fd691c8

Browse files
authored
Merge pull request #826 from stan-dev/mypy-tests
Fix typing issues in tests
2 parents 89dcd4c + 976ac85 commit fd691c8

28 files changed

+276
-233
lines changed

.github/workflows/main.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ jobs:
6868
run: python -m pip freeze
6969

7070
- name: Run flake8, pylint, mypy
71-
if: matrix.python-version == '3.11'
71+
if: matrix.python-version == '3.14'
7272
run: |
7373
flake8 cmdstanpy test
7474
pylint -v cmdstanpy test
75-
mypy cmdstanpy
75+
mypy cmdstanpy test
7676
7777
- name: CmdStan installation cacheing
7878
id: cache-cmdstan

.pre-commit-config.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ repos:
2323
rev: v1.5.0
2424
hooks:
2525
- id: mypy
26-
exclude: ^test/
2726
additional_dependencies: [ numpy >= 1.22]
2827
# local uses the user-installed pylint, this allows dependency checking
2928
- repo: local

cmdstanpy/cmdstan_args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ def __init__(
643643
PathfinderArgs,
644644
],
645645
data: Union[Mapping[str, Any], str, None] = None,
646-
seed: Union[int, list[int], None] = None,
646+
seed: Union[int, np.integer, list[int], list[np.integer], None] = None,
647647
inits: Union[int, float, str, list[str], None] = None,
648648
output_dir: OptionalPath = None,
649649
sig_figs: Optional[int] = None,

cmdstanpy/install_cmdstan.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@
2929
import urllib.error
3030
import urllib.request
3131
from collections import OrderedDict
32+
from functools import cached_property
3233
from pathlib import Path
3334
from time import sleep
34-
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
35+
from typing import Any, Callable, Optional, Union
3536

3637
from tqdm.auto import tqdm
3738

@@ -47,18 +48,6 @@
4748

4849
from . import progress as progbar
4950

50-
if sys.version_info >= (3, 8) or TYPE_CHECKING:
51-
# mypy only knows about the new built-in cached_property
52-
from functools import cached_property
53-
else:
54-
# on older Python versions, this is the recommended
55-
# way to get the same effect
56-
from functools import lru_cache
57-
58-
def cached_property(fun):
59-
return property(lru_cache(maxsize=None)(fun))
60-
61-
6251
try:
6352
# on MacOS and Linux, importing this
6453
# improves the UX of the input() function

cmdstanpy/model.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from concurrent.futures import ThreadPoolExecutor
1414
from io import StringIO
1515
from multiprocessing import cpu_count
16-
from typing import Any, Callable, Mapping, Optional, TypeVar, Union
16+
from typing import Any, Callable, Mapping, Optional, Sequence, TypeVar, Union
1717

1818
import numpy as np
1919
import pandas as pd
@@ -461,8 +461,7 @@ def sample(
461461
Mapping[str, Any],
462462
float,
463463
str,
464-
list[str],
465-
list[Mapping[str, Any]],
464+
Sequence[Union[str, Mapping[str, Any]]],
466465
None,
467466
] = None,
468467
iter_warmup: Optional[int] = None,
@@ -493,7 +492,7 @@ def sample(
493492
str,
494493
np.ndarray,
495494
Mapping[str, Any],
496-
list[Union[str, np.ndarray, Mapping[str, Any]]],
495+
Sequence[Union[str, np.ndarray, Mapping[str, Any]]],
497496
None,
498497
] = None,
499498
) -> CmdStanMCMC:
@@ -1360,7 +1359,13 @@ def pathfinder(
13601359
calculate_lp: bool = True,
13611360
# arguments standard to all methods
13621361
seed: Optional[int] = None,
1363-
inits: Union[dict[str, float], float, str, os.PathLike, None] = None,
1362+
inits: Union[
1363+
Mapping[str, Any],
1364+
float,
1365+
str,
1366+
Sequence[Union[str, Mapping[str, Any]]],
1367+
None,
1368+
] = None,
13641369
output_dir: OptionalPath = None,
13651370
sig_figs: Optional[int] = None,
13661371
save_profile: bool = False,

cmdstanpy/stanfit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def from_csv(
248248
mode: CmdStanMLE = from_csv(
249249
config_dict['mode'], # type: ignore
250250
method='optimize',
251-
) # type: ignore
251+
)
252252
return CmdStanLaplace(runset, mode=mode)
253253
elif config_dict['method'] == 'pathfinder':
254254
pathfinder_args = PathfinderArgs(

cmdstanpy/stanfit/mcmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def _validate_csv_files(self) -> dict[str, Any]:
348348
save_warmup=self._save_warmup,
349349
thin=self._thin,
350350
)
351-
self._chain_time.append(dzero['time']) # type: ignore
351+
self._chain_time.append(dzero['time'])
352352
if not self._is_fixed_param:
353353
self._divergences[i] = dzero['ct_divergences']
354354
self._max_treedepths[i] = dzero['ct_max_treedepth']
@@ -360,7 +360,7 @@ def _validate_csv_files(self) -> dict[str, Any]:
360360
save_warmup=self._save_warmup,
361361
thin=self._thin,
362362
)
363-
self._chain_time.append(drest['time']) # type: ignore
363+
self._chain_time.append(drest['time'])
364364
for key in dzero:
365365
# check args that matter for parsing, plus name, version
366366
if (

cmdstanpy/stanfit/pathfinder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def is_resampled(self) -> bool:
208208
Returns True if the draws were resampled from several Pathfinder
209209
approximations, False otherwise.
210210
"""
211-
return ( # type: ignore
211+
return (
212212
self._metadata.cmdstan_config.get("num_paths", 4) > 1
213213
and self._metadata.cmdstan_config.get('psis_resample', 1)
214214
in (1, 'true')

cmdstanpy/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def show_versions(output: bool = True) -> str:
8080
deps_info.append((module, None))
8181
else:
8282
try:
83-
ver = mod.__version__ # type: ignore
83+
ver = mod.__version__
8484
deps_info.append((module, ver))
8585
# pylint: disable=broad-except
8686
except Exception:

cmdstanpy/utils/filesystem.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import re
99
import shutil
1010
import tempfile
11-
from typing import Any, Iterator, Mapping, Optional, Union
11+
from typing import Any, Iterator, Mapping, Optional, Sequence, Union
1212

1313
import numpy as np
1414

@@ -131,10 +131,12 @@ def _temp_single_json(
131131

132132

133133
def _temp_multiinput(
134-
input: Union[str, os.PathLike, Mapping[str, Any], list[Any], None],
134+
input: Union[str, os.PathLike, Mapping[str, Any], Sequence[Any], None],
135135
base: int = 1,
136136
) -> Iterator[Optional[str]]:
137-
if isinstance(input, list):
137+
if isinstance(input, Sequence) and not isinstance(
138+
input, (str, os.PathLike)
139+
):
138140
# most complicated case: list of inits
139141
# for multiple chains, we need to create multiple files
140142
# which look like somename_{i}.json and then pass somename.json
@@ -170,7 +172,7 @@ def _temp_multiinput(
170172
@contextlib.contextmanager
171173
def temp_metrics(
172174
metrics: Union[
173-
str, os.PathLike, Mapping[str, Any], np.ndarray, list[Any], None
175+
str, os.PathLike, Mapping[str, Any], np.ndarray, Sequence[Any], None
174176
],
175177
*,
176178
id: int = 1,
@@ -200,7 +202,7 @@ def temp_metrics(
200202
@contextlib.contextmanager
201203
def temp_inits(
202204
inits: Union[
203-
str, os.PathLike, Mapping[str, Any], float, int, list[Any], None
205+
str, os.PathLike, Mapping[str, Any], float, int, Sequence[Any], None
204206
],
205207
*,
206208
allow_multiple: bool = True,
@@ -212,7 +214,9 @@ def temp_inits(
212214
if allow_multiple:
213215
yield from _temp_multiinput(inits, base=id)
214216
else:
215-
if isinstance(inits, list):
217+
if isinstance(inits, Sequence) and not isinstance(
218+
inits, (str, os.PathLike)
219+
):
216220
raise ValueError('Expected single initialization, got list')
217221
yield from _temp_single_json(inits)
218222

0 commit comments

Comments
 (0)