Skip to content

Commit 3d4edcd

Browse files
committed
Convert List/Tuple/Dict typing to generics
1 parent c9fa8ce commit 3d4edcd

20 files changed

+373
-604
lines changed

cmdstanpy/cmdstan_args.py

Lines changed: 53 additions & 82 deletions
Large diffs are not rendered by default.

cmdstanpy/compilation.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from copy import copy
1212
from datetime import datetime
1313
from pathlib import Path
14-
from typing import Any, Dict, Iterable, List, Optional, Union
14+
from typing import Any, Iterable, Optional, Union
1515

1616
from cmdstanpy.utils import get_logger
1717
from cmdstanpy.utils.cmdstan import (
@@ -81,8 +81,8 @@ class CompilerOptions:
8181
def __init__(
8282
self,
8383
*,
84-
stanc_options: Optional[Dict[str, Any]] = None,
85-
cpp_options: Optional[Dict[str, Any]] = None,
84+
stanc_options: Optional[dict[str, Any]] = None,
85+
cpp_options: Optional[dict[str, Any]] = None,
8686
user_header: OptionalPath = None,
8787
) -> None:
8888
"""Initialize object."""
@@ -116,12 +116,12 @@ def is_empty(self) -> bool:
116116
)
117117

118118
@property
119-
def stanc_options(self) -> Dict[str, Union[bool, int, str, Iterable[str]]]:
119+
def stanc_options(self) -> dict[str, Union[bool, int, str, Iterable[str]]]:
120120
"""Stanc compiler options."""
121121
return self._stanc_options
122122

123123
@property
124-
def cpp_options(self) -> Dict[str, Union[bool, int]]:
124+
def cpp_options(self) -> dict[str, Union[bool, int]]:
125125
"""C++ compiler options."""
126126
return self._cpp_options
127127

@@ -165,8 +165,7 @@ def validate_stanc_opts(self) -> None:
165165
del self._stanc_options[deprecated]
166166
else:
167167
get_logger().warning(
168-
'compiler option "%s" is deprecated and '
169-
'should not be used',
168+
'compiler option "%s" is deprecated and should not be used',
170169
deprecated,
171170
)
172171
for key, val in self._stanc_options.items():
@@ -225,8 +224,7 @@ def validate_cpp_opts(self) -> None:
225224
val = self._cpp_options[key]
226225
if not isinstance(val, int) or val < 0:
227226
raise ValueError(
228-
f'{key} must be a non-negative integer value,'
229-
f' found {val}.'
227+
f'{key} must be a non-negative integer value, found {val}.'
230228
)
231229

232230
def validate_user_header(self) -> None:
@@ -236,8 +234,7 @@ def validate_user_header(self) -> None:
236234
"""
237235
if self._user_header != "":
238236
if not (
239-
os.path.exists(self._user_header)
240-
and os.path.isfile(self._user_header)
237+
os.path.exists(self._user_header) and os.path.isfile(self._user_header)
241238
):
242239
raise ValueError(
243240
f"User header file {self._user_header} cannot be found"
@@ -275,9 +272,7 @@ def add(self, new_opts: "CompilerOptions") -> None: # noqa: disable=Q000
275272
else:
276273
for key, val in new_opts.stanc_options.items():
277274
if key == 'include-paths':
278-
if isinstance(val, Iterable) and not isinstance(
279-
val, str
280-
):
275+
if isinstance(val, Iterable) and not isinstance(val, str):
281276
for path in val:
282277
self.add_include_path(str(path))
283278
else:
@@ -298,7 +293,7 @@ def add_include_path(self, path: str) -> None:
298293
elif path not in self._stanc_options['include-paths']:
299294
self._stanc_options['include-paths'].append(path)
300295

301-
def compose_stanc(self, filename_in_msg: Optional[str]) -> List[str]:
296+
def compose_stanc(self, filename_in_msg: Optional[str]) -> list[str]:
302297
opts = []
303298

304299
if filename_in_msg is not None:
@@ -322,7 +317,7 @@ def compose_stanc(self, filename_in_msg: Optional[str]) -> List[str]:
322317
opts.append(f'--{key}')
323318
return opts
324319

325-
def compose(self, filename_in_msg: Optional[str] = None) -> List[str]:
320+
def compose(self, filename_in_msg: Optional[str] = None) -> list[str]:
326321
"""
327322
Format makefile options as list of strings.
328323
@@ -342,9 +337,7 @@ def compose(self, filename_in_msg: Optional[str] = None) -> List[str]:
342337
return opts
343338

344339

345-
def src_info(
346-
stan_file: str, compiler_options: CompilerOptions
347-
) -> Dict[str, Any]:
340+
def src_info(stan_file: str, compiler_options: CompilerOptions) -> dict[str, Any]:
348341
"""
349342
Get source info for Stan program file.
350343
@@ -363,15 +356,15 @@ def src_info(
363356
f"Failed to get source info for Stan model "
364357
f"'{stan_file}'. Console:\n{proc.stderr}"
365358
)
366-
result: Dict[str, Any] = json.loads(proc.stdout)
359+
result: dict[str, Any] = json.loads(proc.stdout)
367360
return result
368361

369362

370363
def compile_stan_file(
371364
src: Union[str, Path],
372365
force: bool = False,
373-
stanc_options: Optional[Dict[str, Any]] = None,
374-
cpp_options: Optional[Dict[str, Any]] = None,
366+
stanc_options: Optional[dict[str, Any]] = None,
367+
cpp_options: Optional[dict[str, Any]] = None,
375368
user_header: OptionalPath = None,
376369
) -> str:
377370
"""
@@ -480,7 +473,7 @@ def compile_stan_file(
480473
"If the issue persists please open a bug report"
481474
)
482475
raise ValueError(
483-
f"Failed to compile Stan model '{src}'. " f"Console:\n{console}"
476+
f"Failed to compile Stan model '{src}'. Console:\n{console}"
484477
)
485478
return str(exe_target)
486479

@@ -492,7 +485,7 @@ def format_stan_file(
492485
canonicalize: Union[bool, str, Iterable[str]] = False,
493486
max_line_length: int = 78,
494487
backup: bool = True,
495-
stanc_options: Optional[Dict[str, Any]] = None,
488+
stanc_options: Optional[dict[str, Any]] = None,
496489
) -> None:
497490
"""
498491
Run stanc's auto-formatter on the model code. Either saves directly
@@ -532,9 +525,7 @@ def format_stan_file(
532525
else:
533526
raise ValueError(
534527
"Invalid arguments passed for current CmdStan"
535-
+ " version({})\n".format(
536-
cmdstan_version() or "Unknown"
537-
)
528+
+ " version({})\n".format(cmdstan_version() or "Unknown")
538529
+ "--canonicalize requires 2.29 or higher"
539530
)
540531
else:

cmdstanpy/install_cmdstan.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
--cores: int, number of cores to use when building, defaults to 1
1818
-c, --compiler : flag, add C++ compiler to path (Windows only)
1919
"""
20+
2021
import argparse
2122
import json
2223
import os
@@ -30,7 +31,7 @@
3031
from collections import OrderedDict
3132
from pathlib import Path
3233
from time import sleep
33-
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
34+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
3435

3536
from tqdm.auto import tqdm
3637

@@ -85,7 +86,7 @@ def is_windows() -> bool:
8586
EXTENSION = '.exe' if is_windows() else ''
8687

8788

88-
def get_headers() -> Dict[str, str]:
89+
def get_headers() -> dict[str, str]:
8990
"""Create headers dictionary."""
9091
headers = {}
9192
GITHUB_PAT = os.environ.get("GITHUB_PAT") # pylint:disable=invalid-name
@@ -109,9 +110,7 @@ def latest_version() -> str:
109110
print('retry ({}/5)'.format(i + 1))
110111
sleep(1)
111112
continue
112-
raise CmdStanRetrieveError(
113-
'Cannot connect to CmdStan github repo.'
114-
) from e
113+
raise CmdStanRetrieveError('Cannot connect to CmdStan github repo.') from e
115114
content = json.loads(response.decode('utf-8'))
116115
tag = content['tag_name']
117116
match = re.search(r'v?(.+)', tag)
@@ -287,26 +286,18 @@ def build(verbose: bool = False, progress: bool = True, cores: int = 1) -> None:
287286
raise CmdStanInstallError(f'Command "make build" failed\n{str(e)}')
288287
if not os.path.exists(os.path.join('bin', 'stansummary' + EXTENSION)):
289288
raise CmdStanInstallError(
290-
f'bin/stansummary{EXTENSION} not found'
291-
', please rebuild or report a bug!'
289+
f'bin/stansummary{EXTENSION} not found, please rebuild or report a bug!'
292290
)
293291
if not os.path.exists(os.path.join('bin', 'diagnose' + EXTENSION)):
294292
raise CmdStanInstallError(
295-
f'bin/stansummary{EXTENSION} not found'
296-
', please rebuild or report a bug!'
293+
f'bin/stansummary{EXTENSION} not found, please rebuild or report a bug!'
297294
)
298295

299296
if is_windows():
300297
# Add tbb to the $PATH on Windows
301-
libtbb = os.path.join(
302-
os.getcwd(), 'stan', 'lib', 'stan_math', 'lib', 'tbb'
303-
)
298+
libtbb = os.path.join(os.getcwd(), 'stan', 'lib', 'stan_math', 'lib', 'tbb')
304299
os.environ['PATH'] = ';'.join(
305-
list(
306-
OrderedDict.fromkeys(
307-
[libtbb] + os.environ.get('PATH', '').split(';')
308-
)
309-
)
300+
list(OrderedDict.fromkeys([libtbb] + os.environ.get('PATH', '').split(';')))
310301
)
311302

312303

@@ -417,8 +408,9 @@ def install_version(
417408
)
418409
if overwrite and os.path.exists('.'):
419410
print(
420-
'Overwrite requested, remove existing build of version '
421-
'{}'.format(cmdstan_version)
411+
'Overwrite requested, remove existing build of version {}'.format(
412+
cmdstan_version
413+
)
422414
)
423415
clean_all(verbose)
424416
print('Rebuilding version {}'.format(cmdstan_version))
@@ -485,9 +477,9 @@ def retrieve_version(version: str, progress: bool = True) -> None:
485477
for i in range(6): # always retry to allow for transient URLErrors
486478
try:
487479
if progress and progbar.allow_show_progress():
488-
progress_hook: Optional[
489-
Callable[[int, int, int], None]
490-
] = wrap_url_progress_hook()
480+
progress_hook: Optional[Callable[[int, int, int], None]] = (
481+
wrap_url_progress_hook()
482+
)
491483
else:
492484
progress_hook = None
493485
file_tmp, _ = urllib.request.urlretrieve(
@@ -496,16 +488,13 @@ def retrieve_version(version: str, progress: bool = True) -> None:
496488
break
497489
except urllib.error.HTTPError as e:
498490
raise CmdStanRetrieveError(
499-
'HTTPError: {}\n'
500-
'Version {} not available from github.com.'.format(
491+
'HTTPError: {}\nVersion {} not available from github.com.'.format(
501492
e.code, version
502493
)
503494
) from e
504495
except urllib.error.URLError as e:
505496
print(
506-
'Failed to download CmdStan version {} from github.com'.format(
507-
version
508-
)
497+
'Failed to download CmdStan version {} from github.com'.format(version)
509498
)
510499
print(e)
511500
if i < 5:
@@ -645,14 +634,13 @@ def run_install(args: Union[InteractiveSettings, InstallationSettings]) -> None:
645634
compile_example(args.verbose)
646635

647636

648-
def parse_cmdline_args() -> Dict[str, Any]:
637+
def parse_cmdline_args() -> dict[str, Any]:
649638
parser = argparse.ArgumentParser("install_cmdstan")
650639
parser.add_argument(
651640
'--interactive',
652641
'-i',
653642
action='store_true',
654-
help="Ignore other arguments and run the installation in "
655-
+ "interactive mode",
643+
help="Ignore other arguments and run the installation in " + "interactive mode",
656644
)
657645
parser.add_argument(
658646
'--version',

cmdstanpy/install_cxx_toolchain.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
-m --no-make : don't install mingw32-make (Windows RTools 4.0 only)
1313
--progress : flag, when specified show progress bar for RTools download
1414
"""
15+
1516
import argparse
1617
import os
1718
import platform
@@ -21,7 +22,7 @@
2122
import urllib.request
2223
from collections import OrderedDict
2324
from time import sleep
24-
from typing import Any, Dict, List
25+
from typing import Any
2526

2627
from cmdstanpy import _DOT_CMDSTAN
2728
from cmdstanpy.utils import pushd, validate_dir, wrap_url_progress_hook
@@ -44,7 +45,7 @@ def usage() -> None:
4445
)
4546

4647

47-
def get_config(dir: str, silent: bool) -> List[str]:
48+
def get_config(dir: str, silent: bool) -> list[str]:
4849
"""Assemble config info."""
4950
config = []
5051
if platform.system() == 'Windows':
@@ -243,7 +244,9 @@ def get_url(version: str) -> str:
243244
if version == '4.0':
244245
# pylint: disable=line-too-long
245246
if IS_64BITS:
246-
url = 'https://cran.r-project.org/bin/windows/Rtools/rtools40-x86_64.exe' # noqa: disable=E501
247+
url = (
248+
'https://cran.r-project.org/bin/windows/Rtools/rtools40-x86_64.exe' # noqa: disable=E501
249+
)
247250
else:
248251
url = 'https://cran.r-project.org/bin/windows/Rtools/rtools40-i686.exe' # noqa: disable=E501
249252
elif version == '3.5':
@@ -260,7 +263,7 @@ def get_toolchain_version(name: str, version: str) -> str:
260263
return toolchain_folder
261264

262265

263-
def run_rtools_install(args: Dict[str, Any]) -> None:
266+
def run_rtools_install(args: dict[str, Any]) -> None:
264267
"""Main."""
265268
if platform.system() not in {'Windows'}:
266269
raise NotImplementedError(
@@ -308,9 +311,7 @@ def run_rtools_install(args: Dict[str, Any]) -> None:
308311
else:
309312
if os.path.exists(toolchain_folder):
310313
shutil.rmtree(toolchain_folder, ignore_errors=False)
311-
retrieve_toolchain(
312-
toolchain_folder + EXTENSION, url, progress=progress
313-
)
314+
retrieve_toolchain(toolchain_folder + EXTENSION, url, progress=progress)
314315
install_version(
315316
toolchain_folder,
316317
toolchain_folder + EXTENSION,
@@ -324,16 +325,14 @@ def run_rtools_install(args: Dict[str, Any]) -> None:
324325
and (version in ('4.0', '4', '40'))
325326
):
326327
if os.path.exists(
327-
os.path.join(
328-
toolchain_folder, 'mingw64', 'bin', 'mingw32-make.exe'
329-
)
328+
os.path.join(toolchain_folder, 'mingw64', 'bin', 'mingw32-make.exe')
330329
):
331330
print('mingw32-make.exe already installed')
332331
else:
333332
install_mingw32_make(toolchain_folder, verbose)
334333

335334

336-
def parse_cmdline_args() -> Dict[str, Any]:
335+
def parse_cmdline_args() -> dict[str, Any]:
337336
parser = argparse.ArgumentParser()
338337
parser.add_argument('--version', '-v', help="version, defaults to latest")
339338
parser.add_argument(

0 commit comments

Comments
 (0)