diff --git a/devtools/pytest_plugin.py b/devtools/pytest_plugin.py index f80efd3..97ae439 100644 --- a/devtools/pytest_plugin.py +++ b/devtools/pytest_plugin.py @@ -37,15 +37,42 @@ class ToReplace: insert_assert_summary: ContextVar[list[str]] = ContextVar('insert_assert_summary') -def insert_assert(value: Any) -> int: +def sort_data_from_source(source: Any, value: Any) -> Any: + if isinstance(value, dict) and isinstance(source, dict): + new_dict = {} + used_keys = set() + for k, v in source.items(): + if k in value: + new_dict[k] = sort_data_from_source(v, value[k]) + used_keys.add(k) + for k, v in value.items(): + if k not in used_keys: + new_dict[k] = v + return new_dict + elif isinstance(value, list) and isinstance(source, list): + new_list: list[Any] = [] + for i, v in enumerate(value): + if i < len(source): + new_list.append(sort_data_from_source(source[i], v)) + else: + new_list.append(v) + return new_list + else: + return value + + +def insert_assert(value: Any, prev: Any = None) -> int: call_frame: FrameType = sys._getframe(1) if sys.version_info < (3, 8): # pragma: no cover raise RuntimeError('insert_assert() requires Python 3.8+') - + if prev: + use_value = sort_data_from_source(prev, value) + else: + use_value = value format_code = load_black() ex = Source.for_frame(call_frame).executing(call_frame) if ex.node is None: # pragma: no cover - python_code = format_code(str(custom_repr(value))) + python_code = format_code(str(custom_repr(use_value))) raise RuntimeError( f'insert_assert() was unable to find the frame from which it was called, called with:\n{python_code}' ) @@ -55,7 +82,7 @@ def insert_assert(value: Any) -> int: else: arg = ' '.join(map(str.strip, ex.source.asttokens().get_text(ast_arg).splitlines())) - python_code = format_code(f'# insert_assert({arg})\nassert {arg} == {custom_repr(value)}') + python_code = format_code(f'# insert_assert({arg})\nassert {arg} == {custom_repr(use_value)}') python_code = textwrap.indent(python_code, ex.node.col_offset * ' ') to_replace.append(ToReplace(Path(call_frame.f_code.co_filename), ex.node.lineno, ex.node.end_lineno, python_code)) diff --git a/tests/test_insert_assert.py b/tests/test_insert_assert.py index 299cee8..eee994c 100644 --- a/tests/test_insert_assert.py +++ b/tests/test_insert_assert.py @@ -167,3 +167,76 @@ def test_string_assert(x, insert_assert): ) captured = capsys.readouterr() assert '2 insert skipped because an assert statement on that line had already be inserted!\n' in captured.out + + +def test_insert_assert_sort_data(pytester_pretty): + os.environ.pop('CI', None) + pytester_pretty.makeconftest(config) + test_file = pytester_pretty.makepyfile( + """ +def test_dict(insert_assert): + old_data = { + "foo": 1, + "bar": [ + {"name": "Pydantic", "tags": ["validation", "json"]}, + {"name": "FastAPI", "description": "Web API framework in Python"}, + {"name": "SQLModel"}, + ], + "baz": 3, + } + new_data = { + "bar": [ + { + "description": "Data validation library", + "tags": ["validation", "json"], + "name": "Pydantic", + }, + {"description": "Web API framework in Python", "name": "FastAPI"}, + {"description": "DBs and Python", "name": "SQLModel"}, + {"name": "ARQ"}, + ], + "baz": 6, + "foo": 1, + } + insert_assert(new_data, old_data) +""" + ) + result = pytester_pretty.runpytest() + result.assert_outcomes(passed=1) + assert test_file.read_text() == ( + """def test_dict(insert_assert): + old_data = { + "foo": 1, + "bar": [ + {"name": "Pydantic", "tags": ["validation", "json"]}, + {"name": "FastAPI", "description": "Web API framework in Python"}, + {"name": "SQLModel"}, + ], + "baz": 3, + } + new_data = { + "bar": [ + { + "description": "Data validation library", + "tags": ["validation", "json"], + "name": "Pydantic", + }, + {"description": "Web API framework in Python", "name": "FastAPI"}, + {"description": "DBs and Python", "name": "SQLModel"}, + {"name": "ARQ"}, + ], + "baz": 6, + "foo": 1, + } + # insert_assert(new_data) + assert new_data == { + 'foo': 1, + 'bar': [ + {'name': 'Pydantic', 'tags': ['validation', 'json'], 'description': 'Data validation library'}, + {'name': 'FastAPI', 'description': 'Web API framework in Python'}, + {'name': 'SQLModel', 'description': 'DBs and Python'}, + {'name': 'ARQ'}, + ], + 'baz': 6, + }""" + )