|
1 | 1 | """Testing utilities for CmdStanPy.""" |
2 | 2 |
|
3 | 3 | import contextlib |
4 | | -import os |
5 | | -import sys |
6 | | -import unittest |
| 4 | +import logging |
| 5 | +import platform |
| 6 | +import re |
| 7 | +from typing import List, Type |
| 8 | +from unittest import mock |
7 | 9 | from importlib import reload |
8 | | -from io import StringIO |
| 10 | +import pytest |
9 | 11 |
|
10 | 12 |
|
11 | | -class CustomTestCase(unittest.TestCase): |
12 | | - # pylint: disable=invalid-name |
13 | | - @contextlib.contextmanager |
14 | | - def assertRaisesRegexNested(self, exc, msg): |
15 | | - """A version of assertRaisesRegex that checks the full traceback. |
| 13 | +mark_windows_only = pytest.mark.skipif( |
| 14 | + platform.system() != 'Windows', reason='only runs on windows' |
| 15 | +) |
| 16 | +mark_not_windows = pytest.mark.skipif( |
| 17 | + platform.system() == 'Windows', reason='does not run on windows' |
| 18 | +) |
16 | 19 |
|
17 | | - Useful for when an exception is raised from another and you wish to |
18 | | - inspect the inner exception. |
19 | | - """ |
20 | | - with self.assertRaises(exc) as ctx: |
21 | | - yield |
22 | | - exception = ctx.exception |
23 | | - exn_string = str(ctx.exception) |
24 | | - while exception.__cause__ is not None: |
25 | | - exception = exception.__cause__ |
26 | | - exn_string += "\n" + str(exception) |
27 | | - self.assertRegex(exn_string, msg) |
28 | 20 |
|
29 | | - @contextlib.contextmanager |
30 | | - def without_import(self, library, module): |
31 | | - with unittest.mock.patch.dict('sys.modules', {library: None}): |
32 | | - reload(module) |
33 | | - yield |
34 | | - reload(module) |
| 21 | +# pylint: disable=invalid-name |
| 22 | +@contextlib.contextmanager |
| 23 | +def raises_nested(expected_exception: Type[Exception], match: str) -> None: |
| 24 | + """A version of assertRaisesRegex that checks the full traceback. |
35 | 25 |
|
36 | | - # recipe modified from https://stackoverflow.com/a/36491341 |
37 | | - @contextlib.contextmanager |
38 | | - def replace_stdin(self, target: str): |
39 | | - orig = sys.stdin |
40 | | - sys.stdin = StringIO(target) |
| 26 | + Useful for when an exception is raised from another and you wish to |
| 27 | + inspect the inner exception. |
| 28 | + """ |
| 29 | + with pytest.raises(expected_exception) as ctx: |
41 | 30 | yield |
42 | | - sys.stdin = orig |
43 | | - |
44 | | - # recipe from https://stackoverflow.com/a/34333710 |
45 | | - @contextlib.contextmanager |
46 | | - def modified_environ(self, *remove, **update): |
47 | | - """ |
48 | | - Temporarily updates the ``os.environ`` dictionary in-place. |
49 | | -
|
50 | | - The ``os.environ`` dictionary is updated in-place so that |
51 | | - the modification is sure to work in all situations. |
| 31 | + exception: Exception = ctx.value |
| 32 | + lines = [] |
| 33 | + while exception: |
| 34 | + lines.append(str(exception)) |
| 35 | + exception = exception.__cause__ |
| 36 | + text = "\n".join(lines) |
| 37 | + assert re.search(match, text), f"pattern `{match}` does not match `{text}`" |
52 | 38 |
|
53 | | - :param remove: Environment variables to remove. |
54 | | - :param update: Dictionary of environment variables and values to |
55 | | - add/update. |
56 | | - """ |
57 | | - env = os.environ |
58 | | - update = update or {} |
59 | | - remove = remove or [] |
60 | 39 |
|
61 | | - # List of environment variables being updated or removed. |
62 | | - stomped = (set(update.keys()) | set(remove)) & set(env.keys()) |
63 | | - # Environment variables and values to restore on exit. |
64 | | - update_after = {k: env[k] for k in stomped} |
65 | | - # Environment variables and values to remove on exit. |
66 | | - remove_after = frozenset(k for k in update if k not in env) |
| 40 | +@contextlib.contextmanager |
| 41 | +def without_import(library, module): |
| 42 | + with mock.patch.dict('sys.modules', {library: None}): |
| 43 | + reload(module) |
| 44 | + yield |
| 45 | + reload(module) |
67 | 46 |
|
68 | | - try: |
69 | | - env.update(update) |
70 | | - for k in remove: |
71 | | - env.pop(k, None) |
72 | | - yield |
73 | | - finally: |
74 | | - env.update(update_after) |
75 | | - for k in remove_after: |
76 | | - env.pop(k) |
77 | 47 |
|
78 | | - # pylint: disable=invalid-name |
79 | | - def assertPathsEqual(self, path1, path2): |
80 | | - """Assert paths are equal after normalization""" |
81 | | - self.assertTrue(os.path.samefile(path1, path2)) |
| 48 | +def check_present( |
| 49 | + caplog: pytest.LogCaptureFixture, |
| 50 | + *conditions: List[tuple], |
| 51 | + clear: bool = True, |
| 52 | +) -> None: |
| 53 | + """ |
| 54 | + Check that all desired records exist. |
| 55 | + """ |
| 56 | + for condition in conditions: |
| 57 | + logger, level, message = condition |
| 58 | + if isinstance(level, str): |
| 59 | + level = getattr(logging, level) |
| 60 | + found = any( |
| 61 | + logger == logger_ and level == level_ and message.match(message_) |
| 62 | + if isinstance(message, re.Pattern) |
| 63 | + else message == message_ |
| 64 | + for logger_, level_, message_ in caplog.record_tuples |
| 65 | + ) |
| 66 | + if not found: |
| 67 | + raise ValueError(f"logs did not contain the record {condition}") |
| 68 | + if clear: |
| 69 | + caplog.clear() |
0 commit comments