diff --git a/tools/stronghold/src/api/ast.py b/tools/stronghold/src/api/ast.py index 51a6a57d51..00d63e3537 100644 --- a/tools/stronghold/src/api/ast.py +++ b/tools/stronghold/src/api/ast.py @@ -19,21 +19,28 @@ def extract(path: pathlib.Path) -> Mapping[str, api.Parameters]: * ClassName.method_name * ClassName.SubClassName.method_name """ - raw_api = extract_raw(path) - return { - name: _function_def_to_parameters(function_def) - for name, function_def in raw_api.items() - } + api_map, _ = extract_all(path) + return api_map def extract_raw(path: pathlib.Path) -> Mapping[str, ast.FunctionDef]: """Extracts the API as ast.FunctionDef instances.""" + _, raw = extract_all(path) + return raw + +def extract_all( + path: pathlib.Path, +) -> tuple[Mapping[str, api.Parameters], Mapping[str, ast.FunctionDef]]: + """Extracts both parsed parameters and raw ``ast.FunctionDef`` nodes.""" out: dict[str, ast.FunctionDef] = {} _ContextualNodeVisitor(out, context=[]).visit( ast.parse(path.read_text(), os.fspath(path)) ) - return out - + api_map = { + name: _function_def_to_parameters(function_def) + for name, function_def in out.items() + } + return api_map, out def _function_def_to_parameters(node: ast.FunctionDef) -> api.Parameters: """Converts an ast.FunctionDef to api.Parameters.""" diff --git a/tools/stronghold/src/api/compatibility.py b/tools/stronghold/src/api/compatibility.py index 5dcfc488a2..f23d818253 100644 --- a/tools/stronghold/src/api/compatibility.py +++ b/tools/stronghold/src/api/compatibility.py @@ -6,6 +6,7 @@ import pathlib import tempfile from collections.abc import Iterable, Mapping, Sequence +import ast import api import api.ast @@ -68,13 +69,25 @@ def check( before: pathlib.Path, after: pathlib.Path ) -> Sequence[api.violations.Violation]: """Identifies API compatibility issues between two files.""" - before_api = api.ast.extract(before) - after_api = api.ast.extract(after) + before_api, before_raw = api.ast.extract_all(before) + after_api, after_raw = api.ast.extract_all(after) + + disabled_funcs = { + name + for name, node in before_raw.items() + if _decorator_disables(node) + } | { + name + for name, node in after_raw.items() + if _decorator_disables(node) + } violations: list[api.violations.Violation] = [] for name, before_def in before_api.items(): if any(token.startswith("_") for token in name.split(".")): continue + if name in disabled_funcs: + continue after_def = after_api.get(name) if after_def is None: @@ -320,3 +333,45 @@ def _check_type_compatibility( return False return True + + +def _decorator_disables(node: ast.FunctionDef) -> bool: + """Returns True if the bc_linter.check_compat decorator disables checks.""" + + for deco in node.decorator_list: + name = _decorator_name(deco) + if name == "bc_linter.skip": + return True + if name != "bc_linter.check_compat": + continue + + enable = True + if isinstance(deco, ast.Call): + # Look for keyword argument ``enable`` first + for kw in deco.keywords: + if kw.arg == "enable" and isinstance(kw.value, ast.Constant): + enable = bool(kw.value.value) + break + else: + if len(deco.args) == 1 and isinstance(deco.args[0], ast.Constant): + enable = bool(deco.args[0].value) + + return not enable + + return False + + +def _decorator_name(expr: ast.expr) -> str | None: + """Returns dotted name of decorator if easily determined.""" + + if isinstance(expr, ast.Call): + expr = expr.func + + if isinstance(expr, ast.Name): + return expr.id + if isinstance(expr, ast.Attribute): + value = _decorator_name(expr.value) + if value is None: + return None + return value + "." + expr.attr + return None diff --git a/tools/stronghold/tests/api/test_compatibility.py b/tools/stronghold/tests/api/test_compatibility.py index dc72bec361..0ade6b6abf 100644 --- a/tools/stronghold/tests/api/test_compatibility.py +++ b/tools/stronghold/tests/api/test_compatibility.py @@ -6,6 +6,7 @@ import api.violations import pytest from testing import git, source +import tests.bc_linter_example as bc_linter def test_deleted_function(tmp_path: pathlib.Path) -> None: @@ -520,3 +521,36 @@ def will_be_deleted(): api.violations.FunctionDeleted(func="will_be_deleted", line=1) ], } + + +def test_check_disable_decorator(tmp_path: pathlib.Path) -> None: + @bc_linter.skip + def func(x: int) -> None: + pass # pragma: no cover + + before = source.make_file(tmp_path, func) + + @bc_linter.skip + def func(x: int, y: int) -> None: # type: ignore[no-redef] + pass # pragma: no cover + + after = source.make_file(tmp_path, func) + + assert api.compatibility.check(before, after) == [] + +def test_check_enable_decorator(tmp_path: pathlib.Path) -> None: + @bc_linter.check_compat(enable=True) + def func(x: int) -> None: + pass # pragma: no cover + + before = source.make_file(tmp_path, func) + + @bc_linter.check_compat(enable=True) + def func(x: int, y: int) -> None: # type: ignore[no-redef] + pass # pragma: no cover + + after = source.make_file(tmp_path, func) + + assert api.compatibility.check(before, after) == [ + api.violations.ParameterNowRequired(func=func.__name__, parameter="y", line=2) + ] diff --git a/tools/stronghold/tests/bc_linter_example.py b/tools/stronghold/tests/bc_linter_example.py new file mode 100644 index 0000000000..f388c1f12c --- /dev/null +++ b/tools/stronghold/tests/bc_linter_example.py @@ -0,0 +1,25 @@ +"""Utilities for marking API compatibility checks.""" + +from __future__ import annotations + +from typing import Callable, TypeVar, Any + +F = TypeVar("F", bound=Callable[..., Any]) + + +def check_compat(*, enable: bool = True) -> Callable[[F], F]: + """Decorator used by stronghold to toggle API compatibility checks. + + When ``enable`` is ``False`` the decorated function will be skipped by the + backward compatibility linter. + """ + + def decorator(func: F) -> F: + # Not used in the linter, but useful for debugging. + setattr(func, "_bc_linter_enable", enable) + return func + + return decorator + +# Alias decorator to unconditionally disable the backward compatibility linter. +skip: Callable[[F], F] = check_compat(enable=False) diff --git a/tools/stronghold/tests/bc_linter_vllm.md b/tools/stronghold/tests/bc_linter_vllm.md new file mode 100644 index 0000000000..00eecc404c --- /dev/null +++ b/tools/stronghold/tests/bc_linter_vllm.md @@ -0,0 +1,13 @@ +# BC Linter for vLLM +PR: https://github.com/vllm-project/vllm/pull/21234 + +## Code Path +Cover the following code path: +- vllm/v1/attetion/** +- vllm/v1/core/** + +Additionally, we should have flexibility to cover other code path in the future. + +## Lint Rules +- Check backward compatibility for dataclasses/functions defined python files in code path above +- The default behavior for linter is to check all the dataclasses/public functions in the code path, but we provide an option to skip bc-linter for some experimental dataclasses/functions with `@bc_linter_skip` decorator