diff --git a/.gitignore b/.gitignore index c6e6fb8..d3bf6f0 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ old-version/ /site/ /site.zip /build/ +.venv diff --git a/devtools/pytest_plugin.py b/devtools/pytest_plugin.py index f80efd3..4d9e58f 100644 --- a/devtools/pytest_plugin.py +++ b/devtools/pytest_plugin.py @@ -2,6 +2,7 @@ import ast import builtins +import contextlib import sys import textwrap from contextvars import ContextVar @@ -15,6 +16,7 @@ import pytest from executing import Source +from typing_extensions import Literal from . import debug @@ -30,11 +32,12 @@ class ToReplace: start_line: int end_line: int | None code: str + instruction_type: Literal['insert_assert', 'insert_pytest_raises'] to_replace: list[ToReplace] = [] -insert_assert_calls: ContextVar[int] = ContextVar('insert_assert_calls', default=0) -insert_assert_summary: ContextVar[list[str]] = ContextVar('insert_assert_summary') +test_replacement_calls: ContextVar[int] = ContextVar('insert_assert_calls', default=0) +test_replacement_summary: ContextVar[list[str]] = ContextVar('insert_assert_summary') def insert_assert(value: Any) -> int: @@ -44,7 +47,7 @@ def insert_assert(value: Any) -> int: format_code = load_black() ex = Source.for_frame(call_frame).executing(call_frame) - if ex.node is None: # pragma: no cover + if ex.node is None: python_code = format_code(str(custom_repr(value))) raise RuntimeError( f'insert_assert() was unable to find the frame from which it was called, called with:\n{python_code}' @@ -58,12 +61,60 @@ def insert_assert(value: Any) -> int: python_code = format_code(f'# insert_assert({arg})\nassert {arg} == {custom_repr(value)}') python_code = textwrap.indent(python_code, ex.node.col_offset * ' ') - to_replace.append(ToReplace(Path(call_frame.f_code.co_filename), ex.node.lineno, ex.node.end_lineno, python_code)) - calls = insert_assert_calls.get() + 1 - insert_assert_calls.set(calls) + to_replace.append( + ToReplace( + Path(call_frame.f_code.co_filename), + ex.node.lineno, + ex.node.end_lineno, + python_code, + 'insert_assert', + ) + ) + calls = test_replacement_calls.get() + 1 + test_replacement_calls.set(calls) return calls +@contextlib.contextmanager +def insert_pytest_raises() -> Generator[None, Any, int]: + # We use frame 2 because frame 1 is the context manager itself + call_frame: FrameType = sys._getframe(2) + if sys.version_info < (3, 8): # pragma: no cover + raise RuntimeError('insert_pytest_raises() requires Python 3.8+') + + format_code = load_black() + ex = Source.for_frame(call_frame).executing(call_frame) + if not ex.statements: + raise RuntimeError('insert_pytest_raises() was unable to find the frame from which it was called') + statement = next(iter(ex.statements)) + if not isinstance(statement, ast.With): + raise RuntimeError("insert_pytest_raises() was called outside of a 'with' statement") + if len(ex.statements) > 1 or len(statement.items) > 1: + raise RuntimeError('insert_pytest_raises() was called alongside other statements, this is not supported') + try: + yield + except Exception as e: + python_code = format_code( + f'# with insert_pytest_raises():\n' + f'with pytest.raises({type(e).__name__}, match=re.escape({repr(str(e))})):\n' + ) + python_code = textwrap.indent(python_code, statement.col_offset * ' ') + to_replace.append( + ToReplace( + Path(call_frame.f_code.co_filename), + statement.lineno, + statement.items[0].context_expr.end_lineno, + python_code, + 'insert_pytest_raises', + ) + ) + calls = test_replacement_calls.get() + 1 + test_replacement_calls.set(calls) + return calls + else: + raise RuntimeError('insert_pytest_raises() was called but no exception was raised') + + def pytest_addoption(parser: Any) -> None: parser.addoption( '--insert-assert-print', @@ -83,6 +134,7 @@ def pytest_addoption(parser: Any) -> None: def insert_assert_add_to_builtins() -> None: try: setattr(builtins, 'insert_assert', insert_assert) + setattr(builtins, 'insert_pytest_raises', insert_pytest_raises) # we also install debug here since the default script doesn't install it setattr(builtins, 'debug', debug) except TypeError: @@ -91,14 +143,16 @@ def insert_assert_add_to_builtins() -> None: @pytest.fixture(autouse=True) -def insert_assert_maybe_fail(pytestconfig: pytest.Config) -> Generator[None, None, None]: - insert_assert_calls.set(0) +def test_replacements_maybe_fail(pytestconfig: pytest.Config) -> Generator[None, None, None]: + test_replacement_calls.set(0) yield print_instead = pytestconfig.getoption('insert_assert_print') if not print_instead: - count = insert_assert_calls.get() + count = test_replacement_calls.get() if count: - pytest.fail(f'devtools-insert-assert: {count} assert{plural(count)} will be inserted', pytrace=False) + pytest.fail( + f'devtools-test-replacement: {count} test replacement{plural(count)} will be inserted', pytrace=False + ) @pytest.fixture(name='insert_assert') @@ -106,8 +160,13 @@ def insert_assert_fixture() -> Callable[[Any], int]: return insert_assert +@pytest.fixture(name='insert_pytest_raises') +def insert_pytest_raises_fixture() -> Callable[[], contextlib._GeneratorContextManager[None]]: + return insert_pytest_raises + + def pytest_report_teststatus(report: pytest.TestReport, config: pytest.Config) -> Any: - if report.when == 'teardown' and report.failed and 'devtools-insert-assert:' in repr(report.longrepr): + if report.when == 'teardown' and report.failed and 'devtools-test-replacement:' in repr(report.longrepr): return 'insert assert', 'i', ('INSERT ASSERT', {'cyan': True}) @@ -156,20 +215,30 @@ def insert_assert_session(pytestconfig: pytest.Config) -> Generator[None, None, file.write_text('\n'.join(lines)) files += 1 prefix = 'Printed' if print_instead else 'Replaced' - summary.append( - f'{prefix} {len(to_replace)} insert_assert() call{plural(to_replace)} in {files} file{plural(files)}' - ) + + insert_assert_count = len([item for item in to_replace if item.instruction_type == 'insert_assert']) + insert_pytest_raises_count = len([item for item in to_replace if item.instruction_type == 'insert_pytest_raises']) + if insert_assert_count: + summary.append( + f'{prefix} {insert_assert_count} insert_assert() call{plural(to_replace)} in {files} file{plural(files)}' + ) + if insert_pytest_raises_count: + summary.append( + f'{prefix} {insert_pytest_raises_count} insert_pytest_raises()' + f' call{plural(to_replace)} in {files} file{plural(files)}' + ) if dup_count: summary.append( - f'\n{dup_count} insert skipped because an assert statement on that line had already be inserted!' + f'\n{dup_count} insert{plural(dup_count)}' + ' skipped because an assert statement on that line had already be inserted!' ) - insert_assert_summary.set(summary) + test_replacement_summary.set(summary) to_replace.clear() def pytest_terminal_summary() -> None: - summary = insert_assert_summary.get(None) + summary = test_replacement_summary.get(None) if summary: print('\n'.join(summary)) diff --git a/tests/test_insert_assert.py b/tests/test_insert_assert.py index 299cee8..f8f35bb 100644 --- a/tests/test_insert_assert.py +++ b/tests/test_insert_assert.py @@ -108,11 +108,9 @@ def test_enum(pytester_pretty, capsys): pytester_pretty.makepyfile( """ from enum import Enum - class Foo(Enum): A = 1 B = 2 - def test_deep(insert_assert): x = Foo.A insert_assert(x) @@ -166,4 +164,22 @@ def test_string_assert(x, insert_assert): ' assert x == 1' ) captured = capsys.readouterr() - assert '2 insert skipped because an assert statement on that line had already be inserted!\n' in captured.out + assert '2 inserts skipped because an assert statement on that line had already be inserted!\n' in captured.out + + +def test_insert_assert_frame_not_found(pytester_pretty, capsys): + os.environ.pop('CI', None) + pytester_pretty.makeconftest(config) + pytester_pretty.makepyfile( + """\ +def test_raise_keyerror(insert_assert): + eval('insert_assert(1)') +""" + ) + result = pytester_pretty.runpytest() + result.assert_outcomes(failed=1) + captured = capsys.readouterr() + assert ( + 'RuntimeError: insert_assert() was unable to find the frame from which it was called, called with:\n' + in captured.out + ) diff --git a/tests/test_insert_pytest_raises.py b/tests/test_insert_pytest_raises.py new file mode 100644 index 0000000..54eaee1 --- /dev/null +++ b/tests/test_insert_pytest_raises.py @@ -0,0 +1,175 @@ +import os +import sys + +import pytest + +pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason='requires Python 3.8+') + + +config = "pytest_plugins = ['devtools.pytest_plugin']" +# language=Python +default_test = """\ +import re, pytest +def test_ok(): + assert 1 + 2 == 3 + +def test_value_error(insert_pytest_raises): + with insert_pytest_raises(): + raise ValueError("Such error")\ +""" + + +def test_insert_pytest_raises(pytester_pretty): + os.environ.pop('CI', None) + pytester_pretty.makeconftest(config) + test_file = pytester_pretty.makepyfile(default_test) + result = pytester_pretty.runpytest() + result.assert_outcomes(passed=2) + assert test_file.read_text() == ( + 'import re, pytest\n' + 'def test_ok():\n' + ' assert 1 + 2 == 3\n' + '\n' + 'def test_value_error(insert_pytest_raises):\n' + ' # with insert_pytest_raises():\n' + " with pytest.raises(ValueError, match=re.escape('Such error')):\n" + ' raise ValueError("Such error")' + ) + + +def test_insert_pytest_raises_no_pretty(pytester): + os.environ.pop('CI', None) + pytester.makeconftest(config) + test_file = pytester.makepyfile(default_test) + result = pytester.runpytest('-p', 'no:pretty') + result.assert_outcomes(passed=2) + assert test_file.read_text() == ( + 'import re, pytest\n' + 'def test_ok():\n' + ' assert 1 + 2 == 3\n' + '\n' + 'def test_value_error(insert_pytest_raises):\n' + ' # with insert_pytest_raises():\n' + " with pytest.raises(ValueError, match=re.escape('Such error')):\n" + ' raise ValueError("Such error")' + ) + + +def test_insert_pytest_raises_print(pytester_pretty, capsys): + os.environ.pop('CI', None) + pytester_pretty.makeconftest(config) + test_file = pytester_pretty.makepyfile(default_test) + # assert r == 0 + result = pytester_pretty.runpytest('--insert-assert-print') + result.assert_outcomes(passed=2) + assert test_file.read_text() == default_test + captured = capsys.readouterr() + assert 'test_insert_pytest_raises_print.py - 6:' in captured.out + assert 'Printed 1 insert_pytest_raises() call in 1 file\n' in captured.out + + +def test_insert_pytest_raises_fail(pytester_pretty): + os.environ.pop('CI', None) + pytester_pretty.makeconftest(config) + test_file = pytester_pretty.makepyfile(default_test) + # assert r == 0 + result = pytester_pretty.runpytest() + assert result.parseoutcomes() == {'passed': 2, 'warning': 1, 'insert': 1} + assert test_file.read_text() != default_test + + +def test_insert_pytest_raises_repeat(pytester_pretty, capsys): + os.environ.pop('CI', None) + pytester_pretty.makeconftest(config) + test_file = pytester_pretty.makepyfile( + """\ +import pytest, re + +@pytest.mark.parametrize('x', [1, 2, 3]) +def test_raise_keyerror(x, insert_pytest_raises): + with insert_pytest_raises(): + raise KeyError(x)\ +""" + ) + result = pytester_pretty.runpytest() + result.assert_outcomes(passed=3) + assert test_file.read_text() == ( + 'import pytest, re\n' + '\n' + "@pytest.mark.parametrize('x', [1, 2, 3])\n" + 'def test_raise_keyerror(x, insert_pytest_raises):\n' + ' # with insert_pytest_raises():\n' + " with pytest.raises(KeyError, match=re.escape('1')):\n" + ' raise KeyError(x)' + ) + captured = capsys.readouterr() + assert '2 inserts skipped because an assert statement on that line had already be inserted!\n' in captured.out + + +def test_insert_pytest_raises_frame_not_found(pytester_pretty, capsys): + os.environ.pop('CI', None) + pytester_pretty.makeconftest(config) + pytester_pretty.makepyfile( + """\ +def test_raise_keyerror(insert_pytest_raises): + eval('insert_pytest_raises().__enter__()') +""" + ) + result = pytester_pretty.runpytest() + result.assert_outcomes(failed=1) + captured = capsys.readouterr() + assert ( + 'RuntimeError: insert_pytest_raises() was unable to find the frame from which it was called\n' in captured.out + ) + + +def test_insert_pytest_raises_called_outside_with(pytester_pretty, capsys): + os.environ.pop('CI', None) + pytester_pretty.makeconftest(config) + pytester_pretty.makepyfile( + """\ +def test_raise_keyerror(insert_pytest_raises): + assert insert_pytest_raises().__enter__() == 1 +""" + ) + result = pytester_pretty.runpytest() + result.assert_outcomes(failed=1) + captured = capsys.readouterr() + assert "RuntimeError: insert_pytest_raises() was called outside of a 'with' statement\n" in captured.out + + +def test_insert_pytest_raises_called_with_other_with_statements(pytester_pretty, capsys): + os.environ.pop('CI', None) + pytester_pretty.makeconftest(config) + pytester_pretty.makepyfile( + """\ +import contextlib + +def test_raise_keyerror(insert_pytest_raises): + with contextlib.nullcontext(), insert_pytest_raises(): + raise KeyError(1) +""" + ) + result = pytester_pretty.runpytest() + result.assert_outcomes(failed=1) + captured = capsys.readouterr() + assert ( + 'RuntimeError: insert_pytest_raises() was called alongside other statements, this is not supported\n' + in captured.out + ) + + +def test_insert_pytest_raises_called_with_no_exception(pytester_pretty, capsys): + os.environ.pop('CI', None) + pytester_pretty.makeconftest(config) + pytester_pretty.makepyfile( + """\ +def test_raise_keyerror(insert_pytest_raises): + with insert_pytest_raises(): + assert True +""" + ) + result = pytester_pretty.runpytest() + result.assert_outcomes(failed=1) + captured = capsys.readouterr() + assert 'RuntimeError: insert_pytest_raises() was called but no exception was raised\n' in captured.out