Skip to content

Commit 5382f4d

Browse files
zhewenlizaitsevfb
andauthored
[RFC] add class support in bc-linter (#6953)
As discussed with @izaitsevfb, we also want to include linter checks on classes, in addition to public functions. Before landing this PR, however, we need to create a rules template for the linter in which we can define what to check (or not), for example:

```
pytorch:
  - Files: global
  - Include: functions
vLLM:
  - Files: vllm/v1/attrition, vllm/v1/core
  - Include: classes, functions
```

cc @huydhn @yangw-dev @houseroad

---------

Co-authored-by: Ivan Zaitsev <[email protected]>
1 parent 8aee8b0 commit 5382f4d

File tree

7 files changed

+906
-49
lines changed

7 files changed

+906
-49
lines changed

tools/stronghold/src/api/__init__.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import dataclasses
6-
from collections.abc import Sequence
6+
from collections.abc import Mapping, Sequence
77
from typing import Optional
88

99
import api.types
@@ -41,3 +41,32 @@ class Parameter:
4141
line: int
4242
# Type annotation (relies on ast.annotation types)
4343
type_annotation: Optional[api.types.TypeHint] = None
44+
45+
46+
@dataclasses.dataclass
class Field:
    """Describes one attribute declared in a class body.

    Used for both plain class attributes and dataclass fields.
    """

    # Attribute name as written in the class body.
    name: str
    # Whether a value must be supplied, i.e. the declaration has no default.
    required: bool
    # Line number of the declaration.
    line: int
    # Parsed type annotation, when the declaration has one.
    type_annotation: Optional[api.types.TypeHint] = dataclasses.field(default=None)
    # Whether the field participates in the dataclass __init__ (dataclasses only)
    init: bool = dataclasses.field(default=True)
56+
57+
58+
@dataclasses.dataclass
class Class:
    """Describes a class definition, which may or may not be a dataclass."""

    # Public attributes declared directly in the class body.
    fields: Sequence[Field]
    # Line number of the class definition.
    line: int
    # True when the class is decorated as a dataclass.
    dataclass: bool = dataclasses.field(default=False)
65+
66+
67+
@dataclasses.dataclass
class API:
    """Represents extracted API information."""

    # Function signatures keyed by dotted name relative to the module root
    # (e.g. ``global_func`` or ``ClassName.method_name``).
    functions: Mapping[str, Parameters]
    # Class descriptions keyed by dotted class name.
    classes: Mapping[str, Class]

tools/stronghold/src/api/ast.py

Lines changed: 109 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,27 @@
1111
import api.types
1212

1313

14-
def extract(path: pathlib.Path) -> Mapping[str, api.Parameters]:
15-
"""Extracts the API from a given source file.
16-
17-
The keys will be the fully-qualified path from the root of the module, e.g.
18-
* global_func
19-
* ClassName.method_name
20-
* ClassName.SubClassName.method_name
21-
"""
22-
raw_api = extract_raw(path)
23-
return {
24-
name: _function_def_to_parameters(function_def)
25-
for name, function_def in raw_api.items()
14+
def extract(path: pathlib.Path, *, include_classes: bool = False) -> api.API:
    """Builds the API description of a single source file.

    Function keys are dotted paths from the module root, for example
    ``global_func`` or ``ClassName.method_name``.

    Args:
        path: The Python source file to analyze.
        include_classes: Forwarded to ``extract_raw``; when False the
            resulting ``classes`` mapping is empty.

    Returns:
        An ``api.API`` holding the converted function signatures and the
        classes reported by ``extract_raw``.
    """
    raw_functions, class_map = extract_raw(path, include_classes=include_classes)

    signatures = {}
    for qualified_name, function_def in raw_functions.items():
        signatures[qualified_name] = _function_def_to_parameters(function_def)

    return api.API(functions=signatures, classes=class_map)
22+
2723

24+
def extract_raw(
    path: pathlib.Path, *, include_classes: bool = False
) -> tuple[Mapping[str, ast.FunctionDef], Mapping[str, api.Class]]:
    """Extracts the API of a source file as AST nodes.

    Args:
        path: The Python source file to parse.
        include_classes: When True, class definitions are collected as well.

    Returns:
        A ``(functions, classes)`` pair. ``functions`` maps dotted names to
        their ``ast.FunctionDef`` nodes. ``classes`` maps dotted class names
        to ``api.Class`` and is empty unless ``include_classes`` is set.

    Raises:
        OSError: If the file cannot be read.
        SyntaxError: If the file is not valid Python.
    """
    funcs: dict[str, ast.FunctionDef] = {}
    classes: dict[str, api.Class] = {}

    # Parse from bytes rather than text: the parser then determines the
    # encoding from the file itself (BOM, PEP 263 coding cookie, UTF-8 by
    # default) instead of whatever the process locale happens to be.
    tree = ast.parse(path.read_bytes(), os.fspath(path))

    _ContextualNodeVisitor(funcs, classes if include_classes else None, []).visit(
        tree
    )
    return funcs, classes
3635

3736

3837
def _function_def_to_parameters(node: ast.FunctionDef) -> api.Parameters:
@@ -90,20 +89,105 @@ def _function_def_to_parameters(node: ast.FunctionDef) -> api.Parameters:
9089

9190

9291
class _ContextualNodeVisitor(ast.NodeVisitor):
    """NodeVisitor that collects functions and, optionally, classes.

    Entries are keyed by their dotted path from the module root, for example
    ``global_func``, ``ClassName.method_name`` or ``Outer.Inner``. The names
    of the enclosing classes are carried in ``context``.
    """

    def __init__(
        self,
        functions: dict[str, ast.FunctionDef],
        classes: dict[str, api.Class] | None,
        context: Sequence[str],
    ) -> None:
        """Initializes the visitor.

        Args:
            functions: Output mapping, filled with every visited function.
            classes: Output mapping for classes, or None to skip class
                collection entirely.
            context: Names of the classes enclosing the visited nodes,
                outermost first.
        """
        self._functions = functions
        self._classes = classes
        # Copied so that ``+`` works for any Sequence type and the caller's
        # object is never shared between visitors.
        self._context = list(context)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Records the class (when enabled) and visits its body."""
        qualified = self._context + [node.name]

        if self._classes is not None:
            self._classes[".".join(qualified)] = api.Class(
                fields=self._collect_fields(node),
                line=node.lineno,
                dataclass=any(
                    self._is_dataclass_decorator(dec) for dec in node.decorator_list
                ),
            )

        # Recursively visit all nodes under this class, with the given
        # class name pushed onto a new context.
        _ContextualNodeVisitor(
            self._functions, self._classes, qualified
        ).generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Records this function under its dotted name.

        The function body is not visited, so definitions nested inside a
        function are not collected.
        """
        name = ".".join(self._context + [node.name])
        self._functions[name] = node

    @staticmethod
    def _refers_to(node: ast.AST, name: str) -> bool:
        """Returns True for a bare ``name`` or an attribute access ``x.name``.

        This is a purely syntactic match; imports are not resolved.
        """
        return (isinstance(node, ast.Name) and node.id == name) or (
            isinstance(node, ast.Attribute) and node.attr == name
        )

    @classmethod
    def _is_dataclass_decorator(cls, decorator: ast.expr) -> bool:
        """Returns True if the decorator looks like ``dataclass``.

        Both the bare form (``@dataclass``, ``@dataclasses.dataclass``) and
        the call form (``@dataclass(frozen=True)``) are recognized.
        """
        # A parameterized decorator is an ast.Call wrapping the callable.
        target = decorator.func if isinstance(decorator, ast.Call) else decorator
        return cls._refers_to(target, "dataclass")

    @classmethod
    def _field_flags(cls, value: ast.expr | None) -> tuple[bool, bool]:
        """Derives ``(required, init)`` from the right-hand side of a field.

        Args:
            value: The assigned expression of an annotated declaration, or
                None when the declaration has no value.
        """
        if not (isinstance(value, ast.Call) and cls._refers_to(value.func, "field")):
            # A plain declaration is required exactly when it has no value.
            return value is None, True

        # dataclasses.field(...): default/default_factory imply not required.
        required = not any(
            kw.arg in ("default", "default_factory") for kw in value.keywords
        )
        # Only a literal ``init=False`` removes the field from __init__.
        init_kw = next((kw for kw in value.keywords if kw.arg == "init"), None)
        init = not (
            init_kw is not None
            and isinstance(init_kw.value, ast.Constant)
            and init_kw.value.value is False
        )
        return required, init

    @classmethod
    def _collect_fields(cls, node: ast.ClassDef) -> list[api.Field]:
        """Collects the public attributes declared directly in a class body.

        Names starting with an underscore are skipped. Annotated declarations
        keep their type annotation; plain assignments are recorded without
        one and are never required.
        """
        fields: list[api.Field] = []
        for stmt in node.body:
            if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
                field_name = stmt.target.id
                if field_name.startswith("_"):
                    continue
                required, init = cls._field_flags(stmt.value)
                fields.append(
                    api.Field(
                        name=field_name,
                        required=required,
                        line=stmt.lineno,
                        type_annotation=api.types.annotation_to_dataclass(
                            stmt.annotation
                        ),
                        init=init,
                    )
                )
            elif isinstance(stmt, ast.Assign):
                for target in stmt.targets:
                    if not isinstance(target, ast.Name):
                        continue
                    if target.id.startswith("_"):
                        continue
                    fields.append(
                        api.Field(
                            name=target.id,
                            required=False,
                            line=stmt.lineno,
                            type_annotation=None,
                            init=True,
                        )
                    )
        return fields

tools/stronghold/src/api/compatibility.py

Lines changed: 101 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,39 @@ def check(
6868
before: pathlib.Path, after: pathlib.Path
6969
) -> Sequence[api.violations.Violation]:
7070
"""Identifies API compatibility issues between two files."""
71-
before_api = api.ast.extract(before)
72-
after_api = api.ast.extract(after)
71+
before_api = api.ast.extract(before, include_classes=True)
72+
after_api = api.ast.extract(after, include_classes=True)
73+
before_funcs = before_api.functions
74+
after_funcs = after_api.functions
75+
before_classes = before_api.classes
76+
after_classes = after_api.classes
77+
78+
# Identify deleted classes to avoid double-reporting their methods as deleted
79+
deleted_classes = {
80+
name
81+
for name in before_classes
82+
if not any(token.startswith("_") for token in name.split("."))
83+
and name not in after_classes
84+
}
85+
86+
def _under_deleted_class(func_name: str) -> bool:
87+
tokens = func_name.split(".")
88+
prefix = []
89+
for t in tokens[:-1]: # exclude the function name itself
90+
prefix.append(t)
91+
if ".".join(prefix) in deleted_classes:
92+
return True
93+
return False
7394

7495
violations: list[api.violations.Violation] = []
75-
for name, before_def in before_api.items():
96+
for name, before_def in before_funcs.items():
7697
if any(token.startswith("_") for token in name.split(".")):
7798
continue
99+
if _under_deleted_class(name):
100+
# Will be reported as a class deletion instead
101+
continue
78102

79-
after_def = after_api.get(name)
103+
after_def = after_funcs.get(name)
80104
if after_def is None:
81105
violations.append(api.violations.FunctionDeleted(func=name, line=1))
82106
continue
@@ -103,6 +127,17 @@ def check(
103127
violations += _check_by_requiredness(name, before_def, after_def)
104128
violations += _check_variadic_parameters(name, before_def, after_def)
105129

130+
for name, before_class in before_classes.items():
131+
if any(token.startswith("_") for token in name.split(".")):
132+
continue
133+
after_class = after_classes.get(name)
134+
if after_class is None:
135+
continue
136+
violations += list(_check_class_fields(name, before_class, after_class))
137+
138+
# Classes deleted between before and after
139+
violations += list(_check_deleted_classes(before_classes, after_classes))
140+
106141
return violations
107142

108143

@@ -247,7 +282,68 @@ def _check_variadic_parameters(
247282
if before.variadic_args and not after.variadic_args:
248283
yield api.violations.VarArgsDeleted(func=func, line=after.line)
249284
if before.variadic_kwargs and not after.variadic_kwargs:
250-
yield api.violations.KwArgsDeleted(func, line=after.line)
285+
yield api.violations.KwArgsDeleted(func=func, line=after.line)
286+
287+
288+
def _check_class_fields(
    cls: str, before: api.Class, after: api.Class
) -> Iterable[api.violations.Violation]:
    """Checks class and dataclass field compatibility.

    Args:
        cls: Dotted name of the class, reported as ``func`` on violations.
        before: The class as it was declared previously.
        after: The class as it is declared now.

    Yields:
        ``FieldRemoved`` for every field that disappeared,
        ``FieldTypeChanged`` for every surviving field whose annotation is
        no longer compatible, and ``FieldAdded`` for new dataclass fields
        that are required and part of ``__init__``. Violations are yielded
        in field declaration order.
    """
    before_fields = {f.name: f for f in before.fields}
    after_fields = {f.name: f for f in after.fields}

    for name, before_field in before_fields.items():
        after_field = after_fields.get(name)
        if after_field is None:
            # A removed field has no line of its own in the new file, so the
            # violation is anchored at the class definition.
            yield api.violations.FieldRemoved(func=cls, parameter=name, line=after.line)
            continue

        if not _check_type_compatibility(
            before_field.type_annotation, after_field.type_annotation
        ):
            yield api.violations.FieldTypeChanged(
                func=cls,
                parameter=name,
                line=after_field.line,
                type_before=str(before_field.type_annotation),
                type_after=str(after_field.type_annotation),
            )

    # Policy: Only flag additions for dataclasses, and only when required and in __init__
    if not after.dataclass:
        return

    # Walk the dict (declaration order) rather than a set difference, whose
    # iteration order changes from run to run with string hash randomization.
    for name, added in after_fields.items():
        if name in before_fields:
            continue
        if added.required and added.init:
            yield api.violations.FieldAdded(func=cls, parameter=name, line=added.line)
321+
322+
323+
def _check_deleted_classes(
    before_classes: Mapping[str, api.Class], after_classes: Mapping[str, api.Class]
) -> Iterable[api.violations.Violation]:
    """Yields a ``ClassDeleted`` violation for each public class that vanished.

    A class is public when no component of its dotted name starts with an
    underscore. A class nested inside another deleted class is not reported
    separately; only the outermost deleted class produces a violation.
    """

    def is_public(dotted: str) -> bool:
        return not any(part.startswith("_") for part in dotted.split("."))

    removed = [
        candidate
        for candidate in before_classes
        if is_public(candidate) and candidate not in after_classes
    ]
    removed_lookup = set(removed)

    for name in removed:
        parts = name.split(".")
        # Every proper prefix of the dotted name is an enclosing class.
        enclosing = (".".join(parts[:depth]) for depth in range(1, len(parts)))
        if any(outer in removed_lookup for outer in enclosing):
            continue
        # Align with FunctionDeleted's use of line=1
        yield api.violations.ClassDeleted(func=name, line=1)
251347

252348

253349
def _check_type_compatibility(

tools/stronghold/src/api/violations.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,47 @@ def __post_init__(self) -> None:
123123
self.message = (
124124
f"{self.parameter} changed from {self.type_before} to {self.type_after}"
125125
)
126+
127+
128+
# ====================================
129+
# Class field violations
130+
@dataclass
class FieldViolation(Violation):
    """Base class for violations that concern a single class field."""

    # Name of the affected field.
    parameter: str = ""
133+
134+
135+
@dataclass
class FieldRemoved(FieldViolation):
    """Represents a field being removed from a class."""

    # Overwritten in __post_init__; any value passed in is discarded.
    message: str = ""

    def __post_init__(self) -> None:
        self.message = f"{self.parameter} was removed"
141+
142+
143+
@dataclass
class FieldAdded(FieldViolation):
    """Represents a field being added to a class."""

    # Overwritten in __post_init__; any value passed in is discarded.
    message: str = ""

    def __post_init__(self) -> None:
        self.message = f"{self.parameter} was added"
149+
150+
151+
@dataclass
class FieldTypeChanged(FieldViolation):
    """Represents a field whose type annotation changed."""

    # String forms of the annotation before and after the change.
    type_before: str = ""
    type_after: str = ""
    # Overwritten in __post_init__; any value passed in is discarded.
    message: str = ""

    def __post_init__(self) -> None:
        self.message = (
            f"{self.parameter} changed from {self.type_before} to {self.type_after}"
        )
161+
162+
163+
# ====================================
164+
# Class violations
165+
@dataclass
class ClassDeleted(Violation):
    """Represents a public class being deleted."""

    # Fixed text; unlike the field violations there is no __post_init__
    # here that rewrites it.
    message: str = "class deleted"

0 commit comments

Comments
 (0)