|
| 1 | +from __future__ import annotations as _annotations |
| 2 | + |
| 3 | +import ast |
| 4 | +import builtins |
| 5 | +import sys |
| 6 | +import textwrap |
| 7 | +from contextvars import ContextVar |
| 8 | +from dataclasses import dataclass |
| 9 | +from enum import Enum |
| 10 | +from functools import lru_cache |
| 11 | +from itertools import groupby |
| 12 | +from pathlib import Path |
| 13 | +from types import FrameType |
| 14 | +from typing import TYPE_CHECKING, Any, Callable, Generator, Sized |
| 15 | + |
| 16 | +import pytest |
| 17 | +from executing import Source |
| 18 | + |
| 19 | +from . import debug |
| 20 | + |
| 21 | +if TYPE_CHECKING: |
| 22 | + pass |
| 23 | + |
| 24 | +__all__ = ('insert_assert',) |
| 25 | + |
| 26 | + |
| 27 | +@dataclass |
| 28 | +class ToReplace: |
| 29 | + file: Path |
| 30 | + start_line: int |
| 31 | + end_line: int | None |
| 32 | + code: str |
| 33 | + |
| 34 | + |
| 35 | +to_replace: list[ToReplace] = [] |
| 36 | +insert_assert_calls: ContextVar[int] = ContextVar('insert_assert_calls', default=0) |
| 37 | +insert_assert_summary: ContextVar[list[str]] = ContextVar('insert_assert_summary') |
| 38 | + |
| 39 | + |
| 40 | +def insert_assert(value: Any) -> int: |
| 41 | + call_frame: FrameType = sys._getframe(1) |
| 42 | + if sys.version_info < (3, 8): # pragma: no cover |
| 43 | + raise RuntimeError('insert_assert() requires Python 3.8+') |
| 44 | + |
| 45 | + format_code = load_black() |
| 46 | + ex = Source.for_frame(call_frame).executing(call_frame) |
| 47 | + if ex.node is None: # pragma: no cover |
| 48 | + python_code = format_code(str(custom_repr(value))) |
| 49 | + raise RuntimeError( |
| 50 | + f'insert_assert() was unable to find the frame from which it was called, called with:\n{python_code}' |
| 51 | + ) |
| 52 | + ast_arg = ex.node.args[0] # type: ignore[attr-defined] |
| 53 | + if isinstance(ast_arg, ast.Name): |
| 54 | + arg = ast_arg.id |
| 55 | + else: |
| 56 | + arg = ' '.join(map(str.strip, ex.source.asttokens().get_text(ast_arg).splitlines())) |
| 57 | + |
| 58 | + python_code = format_code(f'# insert_assert({arg})\nassert {arg} == {custom_repr(value)}') |
| 59 | + |
| 60 | + python_code = textwrap.indent(python_code, ex.node.col_offset * ' ') |
| 61 | + to_replace.append(ToReplace(Path(call_frame.f_code.co_filename), ex.node.lineno, ex.node.end_lineno, python_code)) |
| 62 | + calls = insert_assert_calls.get() + 1 |
| 63 | + insert_assert_calls.set(calls) |
| 64 | + return calls |
| 65 | + |
| 66 | + |
| 67 | +def pytest_addoption(parser: Any) -> None: |
| 68 | + parser.addoption( |
| 69 | + '--insert-assert-print', |
| 70 | + action='store_true', |
| 71 | + default=False, |
| 72 | + help='Print statements that would be substituted for insert_assert(), instead of writing to files', |
| 73 | + ) |
| 74 | + parser.addoption( |
| 75 | + '--insert-assert-fail', |
| 76 | + action='store_true', |
| 77 | + default=False, |
| 78 | + help='Fail tests which include one or more insert_assert() calls', |
| 79 | + ) |
| 80 | + |
| 81 | + |
| 82 | +@pytest.fixture(scope='session', autouse=True) |
| 83 | +def insert_assert_add_to_builtins() -> None: |
| 84 | + try: |
| 85 | + setattr(builtins, 'insert_assert', insert_assert) |
| 86 | + # we also install debug here since the default script doesn't install it |
| 87 | + setattr(builtins, 'debug', debug) |
| 88 | + except TypeError: |
| 89 | + # happens on pypy |
| 90 | + pass |
| 91 | + |
| 92 | + |
| 93 | +@pytest.fixture(autouse=True) |
| 94 | +def insert_assert_maybe_fail(pytestconfig: pytest.Config) -> Generator[None, None, None]: |
| 95 | + insert_assert_calls.set(0) |
| 96 | + yield |
| 97 | + print_instead = pytestconfig.getoption('insert_assert_print') |
| 98 | + if not print_instead: |
| 99 | + count = insert_assert_calls.get() |
| 100 | + if count: |
| 101 | + pytest.fail(f'devtools-insert-assert: {count} assert{plural(count)} will be inserted', pytrace=False) |
| 102 | + |
| 103 | + |
| 104 | +@pytest.fixture(name='insert_assert') |
| 105 | +def insert_assert_fixture() -> Callable[[Any], int]: |
| 106 | + return insert_assert |
| 107 | + |
| 108 | + |
| 109 | +def pytest_report_teststatus(report: pytest.TestReport, config: pytest.Config) -> Any: |
| 110 | + if report.when == 'teardown' and report.failed and 'devtools-insert-assert:' in repr(report.longrepr): |
| 111 | + return 'insert assert', 'i', ('INSERT ASSERT', {'cyan': True}) |
| 112 | + |
| 113 | + |
| 114 | +@pytest.fixture(scope='session', autouse=True) |
| 115 | +def insert_assert_session(pytestconfig: pytest.Config) -> Generator[None, None, None]: |
| 116 | + """ |
| 117 | + Actual logic for updating code examples. |
| 118 | + """ |
| 119 | + try: |
| 120 | + __builtins__['insert_assert'] = insert_assert |
| 121 | + except TypeError: |
| 122 | + # happens on pypy |
| 123 | + pass |
| 124 | + |
| 125 | + yield |
| 126 | + |
| 127 | + if not to_replace: |
| 128 | + return None |
| 129 | + |
| 130 | + print_instead = pytestconfig.getoption('insert_assert_print') |
| 131 | + |
| 132 | + highlight = None |
| 133 | + if print_instead: |
| 134 | + highlight = get_pygments() |
| 135 | + |
| 136 | + files = 0 |
| 137 | + dup_count = 0 |
| 138 | + summary = [] |
| 139 | + for file, group in groupby(to_replace, key=lambda tr: tr.file): |
| 140 | + # we have to substitute lines in reverse order to avoid messing up line numbers |
| 141 | + lines = file.read_text().splitlines() |
| 142 | + duplicates: set[int] = set() |
| 143 | + for tr in sorted(group, key=lambda x: x.start_line, reverse=True): |
| 144 | + if print_instead: |
| 145 | + hr = '-' * 80 |
| 146 | + code = highlight(tr.code) if highlight else tr.code |
| 147 | + line_no = f'{tr.start_line}' if tr.start_line == tr.end_line else f'{tr.start_line}-{tr.end_line}' |
| 148 | + summary.append(f'{file} - {line_no}:\n{hr}\n{code}{hr}\n') |
| 149 | + else: |
| 150 | + if tr.start_line in duplicates: |
| 151 | + dup_count += 1 |
| 152 | + else: |
| 153 | + duplicates.add(tr.start_line) |
| 154 | + lines[tr.start_line - 1 : tr.end_line] = tr.code.splitlines() |
| 155 | + if not print_instead: |
| 156 | + file.write_text('\n'.join(lines)) |
| 157 | + files += 1 |
| 158 | + prefix = 'Printed' if print_instead else 'Replaced' |
| 159 | + summary.append( |
| 160 | + f'{prefix} {len(to_replace)} insert_assert() call{plural(to_replace)} in {files} file{plural(files)}' |
| 161 | + ) |
| 162 | + if dup_count: |
| 163 | + summary.append( |
| 164 | + f'\n{dup_count} insert skipped because an assert statement on that line had already be inserted!' |
| 165 | + ) |
| 166 | + |
| 167 | + insert_assert_summary.set(summary) |
| 168 | + to_replace.clear() |
| 169 | + |
| 170 | + |
| 171 | +def pytest_terminal_summary() -> None: |
| 172 | + summary = insert_assert_summary.get(None) |
| 173 | + if summary: |
| 174 | + print('\n'.join(summary)) |
| 175 | + |
| 176 | + |
| 177 | +def custom_repr(value: Any) -> Any: |
| 178 | + if isinstance(value, (list, tuple, set, frozenset)): |
| 179 | + return value.__class__(map(custom_repr, value)) |
| 180 | + elif isinstance(value, dict): |
| 181 | + return value.__class__((custom_repr(k), custom_repr(v)) for k, v in value.items()) |
| 182 | + if isinstance(value, Enum): |
| 183 | + return PlainRepr(f'{value.__class__.__name__}.{value.name}') |
| 184 | + else: |
| 185 | + return PlainRepr(repr(value)) |
| 186 | + |
| 187 | + |
| 188 | +class PlainRepr(str): |
| 189 | + """ |
| 190 | + String class where repr doesn't include quotes. |
| 191 | + """ |
| 192 | + |
| 193 | + def __repr__(self) -> str: |
| 194 | + return str(self) |
| 195 | + |
| 196 | + |
| 197 | +def plural(v: int | Sized) -> str: |
| 198 | + if isinstance(v, (int, float)): |
| 199 | + n = v |
| 200 | + else: |
| 201 | + n = len(v) |
| 202 | + return '' if n == 1 else 's' |
| 203 | + |
| 204 | + |
| 205 | +@lru_cache(maxsize=None) |
| 206 | +def load_black() -> Callable[[str], str]: |
| 207 | + """ |
| 208 | + Build black configuration from "pyproject.toml". |
| 209 | +
|
| 210 | + Black doesn't have a nice self-contained API for reading pyproject.toml, hence all this. |
| 211 | + """ |
| 212 | + try: |
| 213 | + from black import format_file_contents |
| 214 | + from black.files import find_pyproject_toml, parse_pyproject_toml |
| 215 | + from black.mode import Mode, TargetVersion |
| 216 | + from black.parsing import InvalidInput |
| 217 | + except ImportError: |
| 218 | + return lambda x: x |
| 219 | + |
| 220 | + def convert_target_version(target_version_config: Any) -> set[Any] | None: |
| 221 | + if target_version_config is not None: |
| 222 | + return None |
| 223 | + elif not isinstance(target_version_config, list): |
| 224 | + raise ValueError('Config key "target_version" must be a list') |
| 225 | + else: |
| 226 | + return {TargetVersion[tv.upper()] for tv in target_version_config} |
| 227 | + |
| 228 | + @dataclass |
| 229 | + class ConfigArg: |
| 230 | + config_name: str |
| 231 | + keyword_name: str |
| 232 | + converter: Callable[[Any], Any] |
| 233 | + |
| 234 | + config_mapping: list[ConfigArg] = [ |
| 235 | + ConfigArg('target_version', 'target_versions', convert_target_version), |
| 236 | + ConfigArg('line_length', 'line_length', int), |
| 237 | + ConfigArg('skip_string_normalization', 'string_normalization', lambda x: not x), |
| 238 | + ConfigArg('skip_magic_trailing_commas', 'magic_trailing_comma', lambda x: not x), |
| 239 | + ] |
| 240 | + |
| 241 | + config_str = find_pyproject_toml((str(Path.cwd()),)) |
| 242 | + mode_ = None |
| 243 | + fast = False |
| 244 | + if config_str: |
| 245 | + try: |
| 246 | + config = parse_pyproject_toml(config_str) |
| 247 | + except (OSError, ValueError) as e: |
| 248 | + raise ValueError(f'Error reading configuration file: {e}') |
| 249 | + |
| 250 | + if config: |
| 251 | + kwargs = dict() |
| 252 | + for config_arg in config_mapping: |
| 253 | + try: |
| 254 | + value = config[config_arg.config_name] |
| 255 | + except KeyError: |
| 256 | + pass |
| 257 | + else: |
| 258 | + value = config_arg.converter(value) |
| 259 | + if value is not None: |
| 260 | + kwargs[config_arg.keyword_name] = value |
| 261 | + |
| 262 | + mode_ = Mode(**kwargs) |
| 263 | + fast = bool(config.get('fast')) |
| 264 | + |
| 265 | + mode = mode_ or Mode() |
| 266 | + |
| 267 | + def format_code(code: str) -> str: |
| 268 | + try: |
| 269 | + return format_file_contents(code, fast=fast, mode=mode) |
| 270 | + except InvalidInput as e: |
| 271 | + print('black error, you will need to format the code manually,', e) |
| 272 | + return code |
| 273 | + |
| 274 | + return format_code |
| 275 | + |
| 276 | + |
| 277 | +# isatty() is false inside pytest, hence calling this now |
| 278 | +try: |
| 279 | + std_out_istty = sys.stdout.isatty() |
| 280 | +except Exception: |
| 281 | + std_out_istty = False |
| 282 | + |
| 283 | + |
| 284 | +@lru_cache(maxsize=None) |
| 285 | +def get_pygments() -> Callable[[str], str] | None: # pragma: no cover |
| 286 | + if not std_out_istty: |
| 287 | + return None |
| 288 | + try: |
| 289 | + import pygments |
| 290 | + from pygments.formatters import Terminal256Formatter |
| 291 | + from pygments.lexers import PythonLexer |
| 292 | + except ImportError as e: # pragma: no cover |
| 293 | + print(e) |
| 294 | + return None |
| 295 | + else: |
| 296 | + pyg_lexer, terminal_formatter = PythonLexer(), Terminal256Formatter() |
| 297 | + |
| 298 | + def highlight(code: str) -> str: |
| 299 | + return pygments.highlight(code, lexer=pyg_lexer, formatter=terminal_formatter) |
| 300 | + |
| 301 | + return highlight |
0 commit comments