Skip to content

Commit 61c6b67

Browse files
authored
Insert assert (#126)
* support displaying ast types * support 3.7 & 3.8 * skip tests on older python * add insert_assert pytest fixture * use newest pytest-pretty * try to fix CI * fix mypy and black * add pytest to for mypy * fix mypy * change code to install debug in fixture * tweak install instructions
1 parent f0e0fb2 commit 61c6b67

File tree

11 files changed

+531
-22
lines changed

11 files changed

+531
-22
lines changed

.github/workflows/ci.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ on:
88
- '**'
99
pull_request: {}
1010

11+
env:
12+
COLUMNS: 150
13+
1114
jobs:
1215
lint:
1316
runs-on: ubuntu-latest
@@ -51,7 +54,7 @@ jobs:
5154
python-version: ${{ matrix.python-version }}
5255

5356
- run: pip install -r requirements/testing.txt -r requirements/pyproject.txt
54-
- run: pip install .
57+
5558
- run: pip freeze
5659

5760
- name: test with extras

devtools/__main__.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import builtins
2-
import os
32
import sys
43
from pathlib import Path
54

@@ -8,13 +7,17 @@
87
# language=python
98
install_code = """
109
# add devtools `debug` function to builtins
11-
import builtins
12-
try:
13-
from devtools import debug
14-
except ImportError:
15-
pass
16-
else:
17-
setattr(builtins, 'debug', debug)
10+
import sys
11+
# we don't install here for pytest as it breaks pytest, it is
12+
# installed later by a pytest fixture
13+
if not sys.argv[0].endswith('pytest'):
14+
import builtins
15+
try:
16+
from devtools import debug
17+
except ImportError:
18+
pass
19+
else:
20+
setattr(builtins, 'debug', debug)
1821
"""
1922

2023

@@ -47,11 +50,11 @@ def install() -> int:
4750

4851
print(f'Found path "{install_path}" to install devtools into __builtins__')
4952
print('To install devtools, run the following command:\n')
50-
if os.access(install_path, os.W_OK):
51-
print(f' python -m devtools print-code >> {install_path}\n')
52-
else:
53+
print(f' python -m devtools print-code >> {install_path}\n')
54+
if not install_path.is_relative_to(Path.home()):
55+
print('or maybe\n')
5356
print(f' python -m devtools print-code | sudo tee -a {install_path} > /dev/null\n')
54-
print('Note: "sudo" is required because the path is not writable by the current user.')
57+
print('Note: "sudo" might be required because the path is in your home directory.')
5558

5659
return 0
5760

devtools/prettier.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ class SkipPretty(Exception):
4444
@cache
4545
def get_pygments() -> 'Tuple[Any, Any, Any]':
4646
try:
47-
import pygments # type: ignore
48-
from pygments.formatters import Terminal256Formatter # type: ignore
49-
from pygments.lexers import PythonLexer # type: ignore
47+
import pygments
48+
from pygments.formatters import Terminal256Formatter
49+
from pygments.lexers import PythonLexer
5050
except ImportError: # pragma: no cover
5151
return None, None, None
5252
else:

devtools/pytest_plugin.py

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
from __future__ import annotations as _annotations
2+
3+
import ast
4+
import builtins
5+
import sys
6+
import textwrap
7+
from contextvars import ContextVar
8+
from dataclasses import dataclass
9+
from enum import Enum
10+
from functools import lru_cache
11+
from itertools import groupby
12+
from pathlib import Path
13+
from types import FrameType
14+
from typing import TYPE_CHECKING, Any, Callable, Generator, Sized
15+
16+
import pytest
17+
from executing import Source
18+
19+
from . import debug
20+
21+
if TYPE_CHECKING:
22+
pass
23+
24+
__all__ = ('insert_assert',)
25+
26+
27+
@dataclass
28+
class ToReplace:
29+
file: Path
30+
start_line: int
31+
end_line: int | None
32+
code: str
33+
34+
35+
to_replace: list[ToReplace] = []
36+
insert_assert_calls: ContextVar[int] = ContextVar('insert_assert_calls', default=0)
37+
insert_assert_summary: ContextVar[list[str]] = ContextVar('insert_assert_summary')
38+
39+
40+
def insert_assert(value: Any) -> int:
41+
call_frame: FrameType = sys._getframe(1)
42+
if sys.version_info < (3, 8): # pragma: no cover
43+
raise RuntimeError('insert_assert() requires Python 3.8+')
44+
45+
format_code = load_black()
46+
ex = Source.for_frame(call_frame).executing(call_frame)
47+
if ex.node is None: # pragma: no cover
48+
python_code = format_code(str(custom_repr(value)))
49+
raise RuntimeError(
50+
f'insert_assert() was unable to find the frame from which it was called, called with:\n{python_code}'
51+
)
52+
ast_arg = ex.node.args[0] # type: ignore[attr-defined]
53+
if isinstance(ast_arg, ast.Name):
54+
arg = ast_arg.id
55+
else:
56+
arg = ' '.join(map(str.strip, ex.source.asttokens().get_text(ast_arg).splitlines()))
57+
58+
python_code = format_code(f'# insert_assert({arg})\nassert {arg} == {custom_repr(value)}')
59+
60+
python_code = textwrap.indent(python_code, ex.node.col_offset * ' ')
61+
to_replace.append(ToReplace(Path(call_frame.f_code.co_filename), ex.node.lineno, ex.node.end_lineno, python_code))
62+
calls = insert_assert_calls.get() + 1
63+
insert_assert_calls.set(calls)
64+
return calls
65+
66+
67+
def pytest_addoption(parser: Any) -> None:
68+
parser.addoption(
69+
'--insert-assert-print',
70+
action='store_true',
71+
default=False,
72+
help='Print statements that would be substituted for insert_assert(), instead of writing to files',
73+
)
74+
parser.addoption(
75+
'--insert-assert-fail',
76+
action='store_true',
77+
default=False,
78+
help='Fail tests which include one or more insert_assert() calls',
79+
)
80+
81+
82+
@pytest.fixture(scope='session', autouse=True)
83+
def insert_assert_add_to_builtins() -> None:
84+
try:
85+
setattr(builtins, 'insert_assert', insert_assert)
86+
# we also install debug here since the default script doesn't install it
87+
setattr(builtins, 'debug', debug)
88+
except TypeError:
89+
# happens on pypy
90+
pass
91+
92+
93+
@pytest.fixture(autouse=True)
94+
def insert_assert_maybe_fail(pytestconfig: pytest.Config) -> Generator[None, None, None]:
95+
insert_assert_calls.set(0)
96+
yield
97+
print_instead = pytestconfig.getoption('insert_assert_print')
98+
if not print_instead:
99+
count = insert_assert_calls.get()
100+
if count:
101+
pytest.fail(f'devtools-insert-assert: {count} assert{plural(count)} will be inserted', pytrace=False)
102+
103+
104+
@pytest.fixture(name='insert_assert')
105+
def insert_assert_fixture() -> Callable[[Any], int]:
106+
return insert_assert
107+
108+
109+
def pytest_report_teststatus(report: pytest.TestReport, config: pytest.Config) -> Any:
110+
if report.when == 'teardown' and report.failed and 'devtools-insert-assert:' in repr(report.longrepr):
111+
return 'insert assert', 'i', ('INSERT ASSERT', {'cyan': True})
112+
113+
114+
@pytest.fixture(scope='session', autouse=True)
115+
def insert_assert_session(pytestconfig: pytest.Config) -> Generator[None, None, None]:
116+
"""
117+
Actual logic for updating code examples.
118+
"""
119+
try:
120+
__builtins__['insert_assert'] = insert_assert
121+
except TypeError:
122+
# happens on pypy
123+
pass
124+
125+
yield
126+
127+
if not to_replace:
128+
return None
129+
130+
print_instead = pytestconfig.getoption('insert_assert_print')
131+
132+
highlight = None
133+
if print_instead:
134+
highlight = get_pygments()
135+
136+
files = 0
137+
dup_count = 0
138+
summary = []
139+
for file, group in groupby(to_replace, key=lambda tr: tr.file):
140+
# we have to substitute lines in reverse order to avoid messing up line numbers
141+
lines = file.read_text().splitlines()
142+
duplicates: set[int] = set()
143+
for tr in sorted(group, key=lambda x: x.start_line, reverse=True):
144+
if print_instead:
145+
hr = '-' * 80
146+
code = highlight(tr.code) if highlight else tr.code
147+
line_no = f'{tr.start_line}' if tr.start_line == tr.end_line else f'{tr.start_line}-{tr.end_line}'
148+
summary.append(f'{file} - {line_no}:\n{hr}\n{code}{hr}\n')
149+
else:
150+
if tr.start_line in duplicates:
151+
dup_count += 1
152+
else:
153+
duplicates.add(tr.start_line)
154+
lines[tr.start_line - 1 : tr.end_line] = tr.code.splitlines()
155+
if not print_instead:
156+
file.write_text('\n'.join(lines))
157+
files += 1
158+
prefix = 'Printed' if print_instead else 'Replaced'
159+
summary.append(
160+
f'{prefix} {len(to_replace)} insert_assert() call{plural(to_replace)} in {files} file{plural(files)}'
161+
)
162+
if dup_count:
163+
summary.append(
164+
f'\n{dup_count} insert skipped because an assert statement on that line had already be inserted!'
165+
)
166+
167+
insert_assert_summary.set(summary)
168+
to_replace.clear()
169+
170+
171+
def pytest_terminal_summary() -> None:
172+
summary = insert_assert_summary.get(None)
173+
if summary:
174+
print('\n'.join(summary))
175+
176+
177+
def custom_repr(value: Any) -> Any:
178+
if isinstance(value, (list, tuple, set, frozenset)):
179+
return value.__class__(map(custom_repr, value))
180+
elif isinstance(value, dict):
181+
return value.__class__((custom_repr(k), custom_repr(v)) for k, v in value.items())
182+
if isinstance(value, Enum):
183+
return PlainRepr(f'{value.__class__.__name__}.{value.name}')
184+
else:
185+
return PlainRepr(repr(value))
186+
187+
188+
class PlainRepr(str):
189+
"""
190+
String class where repr doesn't include quotes.
191+
"""
192+
193+
def __repr__(self) -> str:
194+
return str(self)
195+
196+
197+
def plural(v: int | Sized) -> str:
198+
if isinstance(v, (int, float)):
199+
n = v
200+
else:
201+
n = len(v)
202+
return '' if n == 1 else 's'
203+
204+
205+
@lru_cache(maxsize=None)
206+
def load_black() -> Callable[[str], str]:
207+
"""
208+
Build black configuration from "pyproject.toml".
209+
210+
Black doesn't have a nice self-contained API for reading pyproject.toml, hence all this.
211+
"""
212+
try:
213+
from black import format_file_contents
214+
from black.files import find_pyproject_toml, parse_pyproject_toml
215+
from black.mode import Mode, TargetVersion
216+
from black.parsing import InvalidInput
217+
except ImportError:
218+
return lambda x: x
219+
220+
def convert_target_version(target_version_config: Any) -> set[Any] | None:
221+
if target_version_config is not None:
222+
return None
223+
elif not isinstance(target_version_config, list):
224+
raise ValueError('Config key "target_version" must be a list')
225+
else:
226+
return {TargetVersion[tv.upper()] for tv in target_version_config}
227+
228+
@dataclass
229+
class ConfigArg:
230+
config_name: str
231+
keyword_name: str
232+
converter: Callable[[Any], Any]
233+
234+
config_mapping: list[ConfigArg] = [
235+
ConfigArg('target_version', 'target_versions', convert_target_version),
236+
ConfigArg('line_length', 'line_length', int),
237+
ConfigArg('skip_string_normalization', 'string_normalization', lambda x: not x),
238+
ConfigArg('skip_magic_trailing_commas', 'magic_trailing_comma', lambda x: not x),
239+
]
240+
241+
config_str = find_pyproject_toml((str(Path.cwd()),))
242+
mode_ = None
243+
fast = False
244+
if config_str:
245+
try:
246+
config = parse_pyproject_toml(config_str)
247+
except (OSError, ValueError) as e:
248+
raise ValueError(f'Error reading configuration file: {e}')
249+
250+
if config:
251+
kwargs = dict()
252+
for config_arg in config_mapping:
253+
try:
254+
value = config[config_arg.config_name]
255+
except KeyError:
256+
pass
257+
else:
258+
value = config_arg.converter(value)
259+
if value is not None:
260+
kwargs[config_arg.keyword_name] = value
261+
262+
mode_ = Mode(**kwargs)
263+
fast = bool(config.get('fast'))
264+
265+
mode = mode_ or Mode()
266+
267+
def format_code(code: str) -> str:
268+
try:
269+
return format_file_contents(code, fast=fast, mode=mode)
270+
except InvalidInput as e:
271+
print('black error, you will need to format the code manually,', e)
272+
return code
273+
274+
return format_code
275+
276+
277+
# isatty() is false inside pytest, hence calling this now
278+
try:
279+
std_out_istty = sys.stdout.isatty()
280+
except Exception:
281+
std_out_istty = False
282+
283+
284+
@lru_cache(maxsize=None)
285+
def get_pygments() -> Callable[[str], str] | None: # pragma: no cover
286+
if not std_out_istty:
287+
return None
288+
try:
289+
import pygments
290+
from pygments.formatters import Terminal256Formatter
291+
from pygments.lexers import PythonLexer
292+
except ImportError as e: # pragma: no cover
293+
print(e)
294+
return None
295+
else:
296+
pyg_lexer, terminal_formatter = PythonLexer(), Terminal256Formatter()
297+
298+
def highlight(code: str) -> str:
299+
return pygments.highlight(code, lexer=pyg_lexer, formatter=terminal_formatter)
300+
301+
return highlight

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ Funding = 'https://github.com/sponsors/samuelcolvin'
4848
Source = 'https://github.com/samuelcolvin/python-devtools'
4949
Changelog = 'https://github.com/samuelcolvin/python-devtools/releases'
5050

51+
[project.entry-points.pytest11]
52+
devtools = 'devtools.pytest_plugin'
53+
5154
[tool.pytest.ini_options]
5255
testpaths = 'tests'
5356
filterwarnings = 'error'
@@ -90,5 +93,5 @@ strict = true
9093
warn_return_any = false
9194

9295
[[tool.mypy.overrides]]
93-
module = ['executing.*']
96+
module = ['executing.*', 'pygments.*']
9497
ignore_missing_imports = true

0 commit comments

Comments
 (0)