Skip to content

Commit c3905f2

Browse files
committed
Merge branch 'develop' into feature/log-prob
2 parents 64fff58 + d5ed366 commit c3905f2

18 files changed

+5666
-5742
lines changed

cmdstanpy/model.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,14 @@ def __init__(
176176
if not cmdstan_version_before(
177177
2, 27
178178
): # unknown end of version range
179-
model_info = self.src_info()
180-
if 'parameters' in model_info:
181-
self._fixed_param |= len(model_info['parameters']) == 0
179+
try:
180+
model_info = self.src_info()
181+
if 'parameters' in model_info:
182+
self._fixed_param |= len(model_info['parameters']) == 0
183+
except ValueError as e:
184+
if compile:
185+
raise
186+
get_logger().debug(e)
182187

183188
if exe_file is not None:
184189
self._exe_file = os.path.realpath(os.path.expanduser(exe_file))
@@ -276,32 +281,22 @@ def src_info(self) -> Dict[str, Any]:
276281
If stanc is older than 2.27 or if the stan
277282
file cannot be found, returns an empty dictionary.
278283
"""
279-
result: Dict[str, Any] = {}
280-
if self.stan_file is None:
281-
return result
282-
try:
283-
cmd = (
284-
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
285-
# handle include-paths, allow-undefined etc
286-
+ self._compiler_options.compose_stanc()
287-
+ [
288-
'--info',
289-
str(self.stan_file),
290-
]
291-
)
292-
proc = subprocess.run(
293-
cmd, capture_output=True, text=True, check=True
284+
if self.stan_file is None or cmdstan_version_before(2, 27):
285+
return {}
286+
cmd = (
287+
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
288+
# handle include-paths, allow-undefined etc
289+
+ self._compiler_options.compose_stanc()
290+
+ ['--info', str(self.stan_file)]
291+
)
292+
proc = subprocess.run(cmd, capture_output=True, text=True, check=False)
293+
if proc.returncode:
294+
raise ValueError(
295+
f"Failed to get source info for Stan model "
296+
f"'{self._stan_file}'. Console:\n{proc.stderr}"
294297
)
295-
result = json.loads(proc.stdout)
296-
return result
297-
except (
298-
ValueError,
299-
RuntimeError,
300-
OSError,
301-
subprocess.CalledProcessError,
302-
) as e:
303-
get_logger().debug(e)
304-
return result
298+
result: Dict[str, Any] = json.loads(proc.stdout)
299+
return result
305300

306301
def format(
307302
self,

requirements-test.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,5 @@ pytest
44
pytest-cov
55
pytest-order
66
mypy
7-
testfixtures
87
tqdm
98
xarray

test/__init__.py

Lines changed: 55 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,69 @@
11
"""Testing utilities for CmdStanPy."""
22

33
import contextlib
4-
import os
5-
import sys
6-
import unittest
4+
import logging
5+
import platform
6+
import re
7+
from typing import List, Type
8+
from unittest import mock
79
from importlib import reload
8-
from io import StringIO
10+
import pytest
911

1012

11-
class CustomTestCase(unittest.TestCase):
12-
# pylint: disable=invalid-name
13-
@contextlib.contextmanager
14-
def assertRaisesRegexNested(self, exc, msg):
15-
"""A version of assertRaisesRegex that checks the full traceback.
13+
mark_windows_only = pytest.mark.skipif(
14+
platform.system() != 'Windows', reason='only runs on windows'
15+
)
16+
mark_not_windows = pytest.mark.skipif(
17+
platform.system() == 'Windows', reason='does not run on windows'
18+
)
1619

17-
Useful for when an exception is raised from another and you wish to
18-
inspect the inner exception.
19-
"""
20-
with self.assertRaises(exc) as ctx:
21-
yield
22-
exception = ctx.exception
23-
exn_string = str(ctx.exception)
24-
while exception.__cause__ is not None:
25-
exception = exception.__cause__
26-
exn_string += "\n" + str(exception)
27-
self.assertRegex(exn_string, msg)
2820

29-
@contextlib.contextmanager
30-
def without_import(self, library, module):
31-
with unittest.mock.patch.dict('sys.modules', {library: None}):
32-
reload(module)
33-
yield
34-
reload(module)
21+
# pylint: disable=invalid-name
22+
@contextlib.contextmanager
23+
def raises_nested(expected_exception: Type[Exception], match: str) -> None:
24+
"""A version of assertRaisesRegex that checks the full traceback.
3525
36-
# recipe modified from https://stackoverflow.com/a/36491341
37-
@contextlib.contextmanager
38-
def replace_stdin(self, target: str):
39-
orig = sys.stdin
40-
sys.stdin = StringIO(target)
26+
Useful for when an exception is raised from another and you wish to
27+
inspect the inner exception.
28+
"""
29+
with pytest.raises(expected_exception) as ctx:
4130
yield
42-
sys.stdin = orig
43-
44-
# recipe from https://stackoverflow.com/a/34333710
45-
@contextlib.contextmanager
46-
def modified_environ(self, *remove, **update):
47-
"""
48-
Temporarily updates the ``os.environ`` dictionary in-place.
49-
50-
The ``os.environ`` dictionary is updated in-place so that
51-
the modification is sure to work in all situations.
31+
exception: Exception = ctx.value
32+
lines = []
33+
while exception:
34+
lines.append(str(exception))
35+
exception = exception.__cause__
36+
text = "\n".join(lines)
37+
assert re.search(match, text), f"pattern `{match}` does not match `{text}`"
5238

53-
:param remove: Environment variables to remove.
54-
:param update: Dictionary of environment variables and values to
55-
add/update.
56-
"""
57-
env = os.environ
58-
update = update or {}
59-
remove = remove or []
6039

61-
# List of environment variables being updated or removed.
62-
stomped = (set(update.keys()) | set(remove)) & set(env.keys())
63-
# Environment variables and values to restore on exit.
64-
update_after = {k: env[k] for k in stomped}
65-
# Environment variables and values to remove on exit.
66-
remove_after = frozenset(k for k in update if k not in env)
40+
@contextlib.contextmanager
41+
def without_import(library, module):
42+
with mock.patch.dict('sys.modules', {library: None}):
43+
reload(module)
44+
yield
45+
reload(module)
6746

68-
try:
69-
env.update(update)
70-
for k in remove:
71-
env.pop(k, None)
72-
yield
73-
finally:
74-
env.update(update_after)
75-
for k in remove_after:
76-
env.pop(k)
7747

78-
# pylint: disable=invalid-name
79-
def assertPathsEqual(self, path1, path2):
80-
"""Assert paths are equal after normalization"""
81-
self.assertTrue(os.path.samefile(path1, path2))
48+
def check_present(
49+
caplog: pytest.LogCaptureFixture,
50+
*conditions: List[tuple],
51+
clear: bool = True,
52+
) -> None:
53+
"""
54+
Check that all desired records exist.
55+
"""
56+
for condition in conditions:
57+
logger, level, message = condition
58+
if isinstance(level, str):
59+
level = getattr(logging, level)
60+
found = any(
61+
logger == logger_ and level == level_ and message.match(message_)
62+
if isinstance(message, re.Pattern)
63+
else message == message_
64+
for logger_, level_, message_ in caplog.record_tuples
65+
)
66+
if not found:
67+
raise ValueError(f"logs did not contain the record {condition}")
68+
if clear:
69+
caplog.clear()

test/conftest.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
DATAFILES_PATH = os.path.join(HERE, 'data')
99

1010

11-
# after we have run all tests, use git to delete the built files in data/
12-
13-
1411
@pytest.fixture(scope='session', autouse=True)
1512
def cleanup_test_files():
1613
"""Remove compiled models and output files after test run."""

test/data/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@
99
*.testbak
1010
*.bak-*
1111
!return_one.hpp
12+
# Ignore temporary files created as part of compilation.
13+
*.o
14+
*.o.tmp

0 commit comments

Comments
 (0)