diff --git a/tools/stronghold/src/api/__init__.py b/tools/stronghold/src/api/__init__.py index 5ed07375f1..a0c039a7f4 100644 --- a/tools/stronghold/src/api/__init__.py +++ b/tools/stronghold/src/api/__init__.py @@ -3,7 +3,7 @@ from __future__ import annotations import dataclasses -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import Optional import api.types @@ -41,3 +41,32 @@ class Parameter: line: int # Type annotation (relies on ast.annotation types) type_annotation: Optional[api.types.TypeHint] = None + + +@dataclasses.dataclass +class Field: + """Represents a dataclass or class attribute.""" + + name: str + required: bool + line: int + type_annotation: Optional[api.types.TypeHint] = None + # Whether the field participates in the dataclass __init__ (dataclasses only) + init: bool = True + + +@dataclasses.dataclass +class Class: + """Represents a class or dataclass.""" + + fields: Sequence[Field] + line: int + dataclass: bool = False + + +@dataclasses.dataclass +class API: + """Represents extracted API information.""" + + functions: Mapping[str, Parameters] + classes: Mapping[str, Class] diff --git a/tools/stronghold/src/api/ast.py b/tools/stronghold/src/api/ast.py index 51a6a57d51..6b088770b6 100644 --- a/tools/stronghold/src/api/ast.py +++ b/tools/stronghold/src/api/ast.py @@ -11,28 +11,27 @@ import api.types -def extract(path: pathlib.Path) -> Mapping[str, api.Parameters]: - """Extracts the API from a given source file. - - The keys will be the fully-qualified path from the root of the module, e.g. - * global_func - * ClassName.method_name - * ClassName.SubClassName.method_name - """ - raw_api = extract_raw(path) - return { - name: _function_def_to_parameters(function_def) - for name, function_def in raw_api.items() +def extract(path: pathlib.Path, *, include_classes: bool = False) -> api.API: + """Extracts API definitions from a given source file.""" + + funcs, classes = extract_raw(path, include_classes=include_classes) + parameters = { + name: _function_def_to_parameters(func) for name, func in funcs.items() } + return api.API(functions=parameters, classes=classes) + +def extract_raw( + path: pathlib.Path, *, include_classes: bool = False +) -> tuple[Mapping[str, ast.FunctionDef], Mapping[str, api.Class]]: + """Extracts API as AST nodes.""" -def extract_raw(path: pathlib.Path) -> Mapping[str, ast.FunctionDef]: - """Extracts the API as ast.FunctionDef instances.""" - out: dict[str, ast.FunctionDef] = {} - _ContextualNodeVisitor(out, context=[]).visit( + funcs: dict[str, ast.FunctionDef] = {} + classes: dict[str, api.Class] = {} + _ContextualNodeVisitor(funcs, classes if include_classes else None, []).visit( ast.parse(path.read_text(), os.fspath(path)) ) - return out + return funcs, classes def _function_def_to_parameters(node: ast.FunctionDef) -> api.Parameters: @@ -90,20 +89,105 @@ def _function_def_to_parameters(node: ast.FunctionDef) -> api.Parameters: class _ContextualNodeVisitor(ast.NodeVisitor): - """NodeVisitor implementation that tracks which class, if any, it is a member of.""" - - def __init__(self, out: dict[str, ast.FunctionDef], context: Sequence[str]) -> None: - self._out = out - self._context = context + """NodeVisitor that collects functions and optionally classes.""" + + def __init__( + self, + functions: dict[str, ast.FunctionDef], + classes: dict[str, api.Class] | None, + context: Sequence[str], + ) -> None: + self._functions = functions + self._classes = classes + self._context = list(context) def visit_ClassDef(self, node: ast.ClassDef) -> None: # Recursively visit all nodes under this class, with the given # class name pushed onto a new context. + if self._classes is not None: + name = ".".join(self._context + [node.name]) + is_dataclass = any( + (isinstance(dec, ast.Name) and dec.id == "dataclass") + or (isinstance(dec, ast.Attribute) and dec.attr == "dataclass") + for dec in node.decorator_list + ) + fields: list[api.Field] = [] + for stmt in node.body: + if isinstance(stmt, ast.AnnAssign) and isinstance( + stmt.target, ast.Name + ): + field_name = stmt.target.id + if field_name.startswith("_"): + continue + required = stmt.value is None + init = True + # Support dataclasses.field(...) + if isinstance(stmt.value, ast.Call): + fn = stmt.value.func + + def _is_field_func(f: ast.AST) -> bool: + return (isinstance(f, ast.Name) and f.id == "field") or ( + isinstance(f, ast.Attribute) and f.attr == "field" + ) + + if _is_field_func(fn): + # default/default_factory imply not required + has_default = any( + isinstance(kw, ast.keyword) + and kw.arg == "default" + and kw.value is not None + for kw in stmt.value.keywords + ) + has_default_factory = any( + isinstance(kw, ast.keyword) + and kw.arg == "default_factory" + and kw.value is not None + for kw in stmt.value.keywords + ) + required = not (has_default or has_default_factory) + # init flag + for kw in stmt.value.keywords: + if isinstance(kw, ast.keyword) and kw.arg == "init": + init = not ( + isinstance(kw.value, ast.Constant) + and kw.value.value is False + ) + break + fields.append( + api.Field( + name=field_name, + required=required, + line=stmt.lineno, + type_annotation=api.types.annotation_to_dataclass( + stmt.annotation + ), + init=init, + ) + ) + elif isinstance(stmt, ast.Assign): + for target in stmt.targets: + if isinstance(target, ast.Name): + field_name = target.id + if field_name.startswith("_"): + continue + fields.append( + api.Field( + name=field_name, + required=False, + line=stmt.lineno, + type_annotation=None, + init=True, + ) + ) + self._classes[name] = api.Class( + fields=fields, line=node.lineno, dataclass=is_dataclass + ) + _ContextualNodeVisitor( - self._out, list(self._context) + [node.name] + self._functions, self._classes, self._context + [node.name] ).generic_visit(node) def visit_FunctionDef(self, node: ast.FunctionDef) -> None: # Records this function. - name = ".".join(list(self._context) + [node.name]) - self._out[name] = node + name = ".".join(self._context + [node.name]) + self._functions[name] = node diff --git a/tools/stronghold/src/api/compatibility.py b/tools/stronghold/src/api/compatibility.py index 5dcfc488a2..4d95c2fe5a 100644 --- a/tools/stronghold/src/api/compatibility.py +++ b/tools/stronghold/src/api/compatibility.py @@ -68,15 +68,39 @@ def check( before: pathlib.Path, after: pathlib.Path ) -> Sequence[api.violations.Violation]: """Identifies API compatibility issues between two files.""" - before_api = api.ast.extract(before) - after_api = api.ast.extract(after) + before_api = api.ast.extract(before, include_classes=True) + after_api = api.ast.extract(after, include_classes=True) + before_funcs = before_api.functions + after_funcs = after_api.functions + before_classes = before_api.classes + after_classes = after_api.classes + + # Identify deleted classes to avoid double-reporting their methods as deleted + deleted_classes = { + name + for name in before_classes + if not any(token.startswith("_") for token in name.split(".")) + and name not in after_classes + } + + def _under_deleted_class(func_name: str) -> bool: + tokens = func_name.split(".") + prefix = [] + for t in tokens[:-1]: # exclude the function name itself + prefix.append(t) + if ".".join(prefix) in deleted_classes: + return True + return False violations: list[api.violations.Violation] = [] - for name, before_def in before_api.items(): + for name, before_def in before_funcs.items(): if any(token.startswith("_") for token in name.split(".")): continue + if _under_deleted_class(name): + # Will be reported as a class deletion instead + continue - after_def = after_api.get(name) + after_def = after_funcs.get(name) if after_def is None: violations.append(api.violations.FunctionDeleted(func=name, line=1)) continue @@ -103,6 +127,17 @@ def check( violations += _check_by_requiredness(name, before_def, after_def) violations += _check_variadic_parameters(name, before_def, after_def) + for name, before_class in before_classes.items(): + if any(token.startswith("_") for token in name.split(".")): + continue + after_class = after_classes.get(name) + if after_class is None: + continue + violations += list(_check_class_fields(name, before_class, after_class)) + + # Classes deleted between before and after + violations += list(_check_deleted_classes(before_classes, after_classes)) + return violations @@ -247,7 +282,68 @@ def _check_variadic_parameters( if before.variadic_args and not after.variadic_args: yield api.violations.VarArgsDeleted(func=func, line=after.line) if before.variadic_kwargs and not after.variadic_kwargs: - yield api.violations.KwArgsDeleted(func, line=after.line) + yield api.violations.KwArgsDeleted(func=func, line=after.line) + + +def _check_class_fields( + cls: str, before: api.Class, after: api.Class +) -> Iterable[api.violations.Violation]: + """Checks class and dataclass field compatibility.""" + + before_fields = {f.name: f for f in before.fields} + after_fields = {f.name: f for f in after.fields} + + for name, before_field in before_fields.items(): + after_field = after_fields.get(name) + if after_field is None: + yield api.violations.FieldRemoved(func=cls, parameter=name, line=after.line) + continue + + if not _check_type_compatibility( + before_field.type_annotation, after_field.type_annotation + ): + yield api.violations.FieldTypeChanged( + func=cls, + parameter=name, + line=after_field.line, + type_before=str(before_field.type_annotation), + type_after=str(after_field.type_annotation), + ) + + for name in set(after_fields) - set(before_fields): + # Policy: Only flag additions for dataclasses, and only when required and in __init__ + if after.dataclass: + added = after_fields[name] + if added.required and getattr(added, "init", True): + yield api.violations.FieldAdded( + func=cls, parameter=name, line=added.line + ) + + +def _check_deleted_classes( + before_classes: Mapping[str, api.Class], after_classes: Mapping[str, api.Class] +) -> Iterable[api.violations.Violation]: + """Emits violations for classes deleted between before and after.""" + deleted = [ + name + for name in before_classes + if not any(token.startswith("_") for token in name.split(".")) + and name not in after_classes + ] + + deleted_set = set(deleted) + + def has_deleted_ancestor(class_name: str) -> bool: + tokens = class_name.split(".") + for i in range(1, len(tokens)): + if ".".join(tokens[:i]) in deleted_set: + return True + return False + + for name in deleted: + if not has_deleted_ancestor(name): + # Align with FunctionDeleted's use of line=1 + yield api.violations.ClassDeleted(func=name, line=1) def _check_type_compatibility( diff --git a/tools/stronghold/src/api/violations.py b/tools/stronghold/src/api/violations.py index 53342d7d84..48db8da80e 100644 --- a/tools/stronghold/src/api/violations.py +++ b/tools/stronghold/src/api/violations.py @@ -123,3 +123,47 @@ def __post_init__(self) -> None: self.message = ( f"{self.parameter} changed from {self.type_before} to {self.type_after}" ) + + +# ==================================== +# Class field violations +@dataclass +class FieldViolation(Violation): + parameter: str = "" + + +@dataclass +class FieldRemoved(FieldViolation): + message: str = "" + + def __post_init__(self) -> None: + self.message = f"{self.parameter} was removed" + + +@dataclass +class FieldAdded(FieldViolation): + message: str = "" + + def __post_init__(self) -> None: + self.message = f"{self.parameter} was added" + + +@dataclass +class FieldTypeChanged(FieldViolation): + type_before: str = "" + type_after: str = "" + message: str = "" + + def __post_init__(self) -> None: + self.message = ( + f"{self.parameter} changed from {self.type_before} to {self.type_after}" + ) + + +# ==================================== +# Class violations +@dataclass +class ClassDeleted(Violation): + """Represents a public class being deleted.""" + + message: str = "class deleted" diff --git a/tools/stronghold/tests/api/test_ast.py b/tools/stronghold/tests/api/test_ast.py index 9410ba877e..dfa762859b 100644 --- a/tools/stronghold/tests/api/test_ast.py +++ b/tools/stronghold/tests/api/test_ast.py @@ -1,5 +1,6 @@ """Tests the api.ast module.""" +import dataclasses import pathlib import api @@ -12,7 +13,7 @@ def test_extract_empty(tmp_path: pathlib.Path) -> None: def func() -> None: pass # pragma: no cover - funcs = api.ast.extract(source.make_file(tmp_path, func)) + funcs = api.ast.extract(source.make_file(tmp_path, func)).functions assert funcs == { "func": api.Parameters( parameters=[], variadic_args=False, variadic_kwargs=False, line=1 @@ -24,7 +25,7 @@ def test_extract_positional(tmp_path: pathlib.Path) -> None: def func(x: int, /) -> None: pass # pragma: no cover - funcs = api.ast.extract(source.make_file(tmp_path, func)) + funcs = api.ast.extract(source.make_file(tmp_path, func)).functions assert funcs == { "func": api.Parameters( parameters=[ @@ -48,7 +49,7 @@ def test_extract_positional_with_default(tmp_path: pathlib.Path) -> None: def func(x: int = 0, /) -> None: pass # pragma: no cover - funcs = api.ast.extract(source.make_file(tmp_path, func)) + funcs = api.ast.extract(source.make_file(tmp_path, func)).functions assert funcs == { "func": api.Parameters( parameters=[ @@ -72,7 +73,7 @@ def test_extract_flexible(tmp_path: pathlib.Path) -> None: def func(x: int) -> None: pass # pragma: no cover - funcs = api.ast.extract(source.make_file(tmp_path, func)) + funcs = api.ast.extract(source.make_file(tmp_path, func)).functions assert funcs == { "func": api.Parameters( parameters=[ @@ -96,7 +97,7 @@ def test_extract_flexible_with_default(tmp_path: pathlib.Path) -> None: def func(x: int = 0) -> None: pass # pragma: no cover - funcs = api.ast.extract(source.make_file(tmp_path, func)) + funcs = api.ast.extract(source.make_file(tmp_path, func)).functions assert funcs == { "func": api.Parameters( parameters=[ @@ -120,7 +121,7 @@ def test_extract_keyword(tmp_path: pathlib.Path) -> None: def func(*, x: int) -> None: pass # pragma: no cover - funcs = api.ast.extract(source.make_file(tmp_path, func)) + funcs = api.ast.extract(source.make_file(tmp_path, func)).functions assert funcs == { "func": api.Parameters( parameters=[ @@ -144,7 +145,7 @@ def test_extract_keyword_with_default(tmp_path: pathlib.Path) -> None: def func(*, x: int = 0) -> None: pass # pragma: no cover - funcs = api.ast.extract(source.make_file(tmp_path, func)) + funcs = api.ast.extract(source.make_file(tmp_path, func)).functions assert funcs == { "func": api.Parameters( parameters=[ @@ -168,7 +169,7 @@ def test_extract_variadic_args(tmp_path: pathlib.Path) -> None: def func(*args: int) -> None: pass # pragma: no cover - funcs = api.ast.extract(source.make_file(tmp_path, func)) + funcs = api.ast.extract(source.make_file(tmp_path, func)).functions assert funcs == { "func": api.Parameters( parameters=[], variadic_args=True, variadic_kwargs=False, line=1 @@ -180,7 +181,7 @@ def test_extract_variadic_kwargs(tmp_path: pathlib.Path) -> None: def func(**kwargs: int) -> None: pass # pragma: no cover - funcs = api.ast.extract(source.make_file(tmp_path, func)) + funcs = api.ast.extract(source.make_file(tmp_path, func)).functions assert funcs == { "func": api.Parameters( parameters=[], variadic_args=False, variadic_kwargs=True, line=1 @@ -193,7 +194,7 @@ class Class: def func(self, /) -> None: pass # pragma: no cover - funcs = api.ast.extract(source.make_file(tmp_path, Class)) + funcs = api.ast.extract(source.make_file(tmp_path, Class)).functions assert funcs == { "Class.func": api.Parameters( parameters=[ @@ -212,14 +213,74 @@ def func(self, /) -> None: } +def test_extract_dataclass(tmp_path: pathlib.Path) -> None: + @dataclasses.dataclass + class Class: + a: int + b: int = 1 + + classes = api.ast.extract( + source.make_file(tmp_path, Class), include_classes=True + ).classes + assert classes == { + "Class": api.Class( + fields=[ + api.Field( + name="a", + required=True, + line=3, + type_annotation=api.types.TypeName("int"), + ), + api.Field( + name="b", + required=False, + line=4, + type_annotation=api.types.TypeName("int"), + ), + ], + line=2, + dataclass=True, + ) + } + + def test_extract_comprehensive(tmp_path: pathlib.Path) -> None: class Class: + a: int + b: float = 1.0 + def func( self, a: int, /, b: float = 2, *args: int, c: int, **kwargs: int ) -> None: pass # pragma: no cover - funcs = api.ast.extract(source.make_file(tmp_path, Class)) + extract_api = api.ast.extract( + source.make_file(tmp_path, Class), include_classes=True + ) + funcs = extract_api.functions + classes = extract_api.classes + + assert classes == { + "Class": api.Class( + fields=[ + api.Field( + name="a", + required=True, + line=2, + type_annotation=api.types.TypeName("int"), + ), + api.Field( + name="b", + required=False, + line=3, + type_annotation=api.types.TypeName("float"), + ), + ], + line=1, + dataclass=False, + ) + } + assert funcs == { "Class.func": api.Parameters( parameters=[ @@ -228,14 +289,14 @@ def func( positional=True, keyword=False, required=True, - line=3, + line=6, ), api.Parameter( name="a", positional=True, keyword=False, required=True, - line=3, + line=6, type_annotation=api.types.TypeName("int"), ), api.Parameter( @@ -243,7 +304,7 @@ def func( positional=True, keyword=True, required=False, - line=3, + line=6, type_annotation=api.types.TypeName("float"), ), api.Parameter( @@ -251,12 +312,12 @@ def func( positional=False, keyword=True, required=True, - line=3, + line=6, type_annotation=api.types.TypeName("int"), ), ], variadic_args=True, variadic_kwargs=True, - line=2, + line=5, ) } diff --git a/tools/stronghold/tests/api/test_ast_param_types.py b/tools/stronghold/tests/api/test_ast_param_types.py index 063709de98..1d7ba84d61 100644 --- a/tools/stronghold/tests/api/test_ast_param_types.py +++ b/tools/stronghold/tests/api/test_ast_param_types.py @@ -13,7 +13,7 @@ def extract_parameter_types( tmp_path: pathlib.Path, ) -> List[Optional[api.types.TypeHint]]: """Extracts the parameter types from a function definition.""" - funcs = api.ast.extract(tmp_path) + funcs = api.ast.extract(tmp_path).functions if not funcs: return [] return [ diff --git a/tools/stronghold/tests/api/test_compatibility.py b/tools/stronghold/tests/api/test_compatibility.py index dc72bec361..f538cd4c37 100644 --- a/tools/stronghold/tests/api/test_compatibility.py +++ b/tools/stronghold/tests/api/test_compatibility.py @@ -51,7 +51,7 @@ def func(self, /) -> None: after = source.make_file(tmp_path, lambda: None) assert api.compatibility.check(before, after) == [ - api.violations.FunctionDeleted(func="Class.func", line=1) + api.violations.ClassDeleted(func="Class", line=1), ] @@ -227,6 +227,246 @@ def func(x: int = 0, /) -> None: # type: ignore[no-redef] assert api.compatibility.check(before, after) == [] +def test_parameter_annotation_removed_no_violation(tmp_path: pathlib.Path) -> None: + def func(x: int) -> None: + pass # pragma: no cover + + before = source.make_file(tmp_path, func) + + def func(x) -> None: # type: ignore[no-redef] + pass # pragma: no cover + + after = source.make_file(tmp_path, func) + + assert api.compatibility.check(before, after) == [] + + +def test_parameter_annotation_added_no_violation(tmp_path: pathlib.Path) -> None: + def func(x) -> None: + pass # pragma: no cover + + before = source.make_file(tmp_path, func) + + def func(x: int) -> None: # type: ignore[no-redef] + pass # pragma: no cover + + after = source.make_file(tmp_path, func) + + assert api.compatibility.check(before, after) == [] + + +def test_deleted_inner_class_only(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before_inner_deleted.py" + before.write_text( + textwrap.dedent( + """ + class Outer: + class Inner: + pass + """ + ) + ) + + after = tmp_path / "after_inner_deleted.py" + after.write_text( + textwrap.dedent( + """ + class Outer: + pass + """ + ) + ) + + assert api.compatibility.check(before, after) == [ + api.violations.ClassDeleted(func="Outer.Inner", line=1) + ] + + +def test_deleted_outer_class_collapses_inner_deletions(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before_outer_deleted.py" + before.write_text( + textwrap.dedent( + """ + class Outer: + class Inner: + pass + """ + ) + ) + + after = tmp_path / "after_outer_deleted.py" + after.write_text("") + + violations = api.compatibility.check(before, after) + deleted = sorted( + v.func for v in violations if isinstance(v, api.violations.ClassDeleted) + ) + assert deleted == ["Outer"] + + +def test_method_removed_only_no_class_deleted(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before_method_removed.py" + before.write_text( + textwrap.dedent( + """ + class Class: + def m(self): + pass + """ + ) + ) + + after = tmp_path / "after_method_removed.py" + after.write_text( + textwrap.dedent( + """ + class Class: + pass + """ + ) + ) + + # Class remains; method deletion should be reported as FunctionDeleted + assert api.compatibility.check(before, after) == [ + api.violations.FunctionDeleted(func="Class.m", line=1) + ] + + +def test_class_renamed_emits_class_deleted(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before_class_renamed.py" + before.write_text( + textwrap.dedent( + """ + class Class: + a = 1 + """ + ) + ) + + after = tmp_path / "after_class_renamed.py" + after.write_text( + textwrap.dedent( + """ + class Renamed: + a = 1 + """ + ) + ) + + assert api.compatibility.check(before, after) == [ + api.violations.ClassDeleted(func="Class", line=1) + ] + + +def test_dataclass_field_default_change_no_violation(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before_dc_default.py" + before.write_text( + textwrap.dedent( + """ + import dataclasses + @dataclasses.dataclass + class Class: + a: int = 1 + """ + ) + ) + + after = tmp_path / "after_dc_default.py" + after.write_text( + textwrap.dedent( + """ + import dataclasses + @dataclasses.dataclass + class Class: + a: int = 2 + """ + ) + ) + + assert api.compatibility.check(before, after) == [] + + +def test_class_field_order_reordered_no_violation(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before_field_order.py" + before.write_text( + textwrap.dedent( + """ + class Class: + a = 1 + b = 2 + """ + ) + ) + + after = tmp_path / "after_field_order.py" + after.write_text( + textwrap.dedent( + """ + class Class: + b = 2 + a = 1 + """ + ) + ) + + assert api.compatibility.check(before, after) == [] + + +def test_nested_private_class_deleted_no_violation(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before_nested_private_cls.py" + before.write_text( + textwrap.dedent( + """ + class Outer: + class _Inner: + pass + """ + ) + ) + + after = tmp_path / "after_nested_private_cls.py" + after.write_text( + textwrap.dedent( + """ + class Outer: + pass + """ + ) + ) + + assert api.compatibility.check(before, after) == [] + + +def test_dataclass_required_to_optional_field_no_violation( + tmp_path: pathlib.Path, +) -> None: + before = tmp_path / "before_dc_required_optional.py" + before.write_text( + textwrap.dedent( + """ + import dataclasses + @dataclasses.dataclass + class Class: + a: int + """ + ) + ) + + after = tmp_path / "after_dc_required_optional.py" + after.write_text( + textwrap.dedent( + """ + import dataclasses + @dataclasses.dataclass + class Class: + a: int = 1 + """ + ) + ) + + assert api.compatibility.check(before, after) == [] + + def test_new_optional_flexible_parameter(tmp_path: pathlib.Path) -> None: def func() -> None: pass # pragma: no cover @@ -520,3 +760,306 @@ def will_be_deleted(): api.violations.FunctionDeleted(func="will_be_deleted", line=1) ], } + + +def test_class_field_removed(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before_cls.py" + before.write_text( + textwrap.dedent( + """ + class Class: + a = 1 + b = 2 + """ + ) + ) + + after = tmp_path / "after_cls.py" + after.write_text( + textwrap.dedent( + """ + class Class: + a = 1 + """ + ) + ) + + assert api.compatibility.check(before, after) == [ + api.violations.FieldRemoved(func="Class", parameter="b", line=2) + ] + + +def test_dataclass_field_removed(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before.py" + before.write_text( + textwrap.dedent( + """ + @dataclasses.dataclass + class Class: + a: int + b: int + """ + ) + ) + + after = tmp_path / "after.py" + after.write_text( + textwrap.dedent( + """ + @dataclasses.dataclass + class Class: + a: int + """ + ) + ) + + assert api.compatibility.check(before, after) == [ + api.violations.FieldRemoved(func="Class", parameter="b", line=3) + ] + + +def test_dataclass_field_type_changed(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before_type.py" + before.write_text( + textwrap.dedent( + """ + @dataclasses.dataclass + class Class: + a: int + """ + ) + ) + + after = tmp_path / "after_type.py" + after.write_text( + textwrap.dedent( + """ + @dataclasses.dataclass + class Class: + a: str + """ + ) + ) + + assert api.compatibility.check(before, after) == [ + api.violations.FieldTypeChanged( + func="Class", + parameter="a", + line=4, + type_before="int", + type_after="str", + ) + ] + + +def test_class_field_added(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before_cls_add.py" + before.write_text( + textwrap.dedent( + """ + class Class: + a = 1 + """ + ) + ) + + after = tmp_path / "after_cls_add.py" + after.write_text( + textwrap.dedent( + """ + class Class: + a = 1 + b = 2 + """ + ) + ) + + # Adding a field to a regular class is not a BC violation + assert api.compatibility.check(before, after) == [] + + +def test_dataclass_field_added(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before_dc_add.py" + before.write_text( + textwrap.dedent( + """ + @dataclasses.dataclass + class Class: + a: int + """ + ) + ) + + after = tmp_path / "after_dc_add.py" + after.write_text( + textwrap.dedent( + """ + @dataclasses.dataclass + class Class: + a: int + b: int + """ + ) + ) + + assert api.compatibility.check(before, after) == [ + api.violations.FieldAdded(func="Class", parameter="b", line=5) + ] + + +def test_dataclass_field_added_with_default_no_violation( + tmp_path: pathlib.Path, +) -> None: + before = tmp_path / "before_dc_add_default.py" + before.write_text( + textwrap.dedent( + """ + @dataclasses.dataclass + class Class: + a: int + """ + ) + ) + + after = tmp_path / "after_dc_add_default.py" + after.write_text( + textwrap.dedent( + """ + @dataclasses.dataclass + class Class: + a: int + b: int = 0 + """ + ) + ) + + assert api.compatibility.check(before, after) == [] + + +def test_dataclass_field_added_with_default_factory_no_violation( + tmp_path: pathlib.Path, +) -> None: + before = tmp_path / "before_dc_add_factory.py" + before.write_text( + textwrap.dedent( + """ + import dataclasses + @dataclasses.dataclass + class Class: + a: int + """ + ) + ) + + after = tmp_path / "after_dc_add_factory.py" + after.write_text( + textwrap.dedent( + """ + import dataclasses + @dataclasses.dataclass + class Class: + a: int + b: list[int] = dataclasses.field(default_factory=list) + """ + ) + ) + + assert api.compatibility.check(before, after) == [] + + +def test_dataclass_field_added_init_false_no_violation(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before_dc_add_init_false.py" + before.write_text( + textwrap.dedent( + """ + import dataclasses + @dataclasses.dataclass + class Class: + a: int + """ + ) + ) + + after = tmp_path / "after_dc_add_init_false.py" + after.write_text( + textwrap.dedent( + """ + import dataclasses + @dataclasses.dataclass + class Class: + a: int + b: int = dataclasses.field(init=False, default=0) + """ + ) + ) + + assert api.compatibility.check(before, after) == [] + + +def test_class_deleted_violation(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before_class_deleted.py" + before.write_text( + textwrap.dedent( + """ + class Class: + a = 1 + """ + ) + ) + + after = tmp_path / "after_class_deleted.py" + after.write_text("") + + assert api.compatibility.check(before, after) == [ + api.violations.ClassDeleted(func="Class", line=1) + ] + + +def test_private_class_field_changes_no_violation(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before_private_cls.py" + before.write_text( + textwrap.dedent( + """ + class Class: + _a = 1 + """ + ) + ) + + after = tmp_path / "after_private_cls.py" + after.write_text( + textwrap.dedent( + """ + class Class: + _a = 2 + """ + ) + ) + + assert api.compatibility.check(before, after) == [] + + +def test_private_dataclass_field_changes_no_violation(tmp_path: pathlib.Path) -> None: + before = tmp_path / "before_private_dc.py" + before.write_text( + textwrap.dedent( + """ + @dataclasses.dataclass + class Class: + _a: int + """ + ) + ) + + after = tmp_path / "after_private_dc.py" + after.write_text( + textwrap.dedent( + """ + @dataclasses.dataclass + class Class: + _a: str + """ + ) + ) + + assert api.compatibility.check(before, after) == []