From 85a2a3cfea7879098ea32da986ceabf2a0d3f8ff Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Tue, 23 Apr 2024 04:16:03 +0100 Subject: [PATCH 1/6] Define ``__eq__`` and ``__hash__`` methods for AST types --- sphinx/domains/c/_ast.py | 415 +++++++++++++++++++- sphinx/domains/cpp/_ast.py | 786 ++++++++++++++++++++++++++++++++++++- sphinx/util/cfamily.py | 45 ++- 3 files changed, 1239 insertions(+), 7 deletions(-) diff --git a/sphinx/domains/c/_ast.py b/sphinx/domains/c/_ast.py index 3a8e2a2a4cb..2ff3c270f19 100644 --- a/sphinx/domains/c/_ast.py +++ b/sphinx/domains/c/_ast.py @@ -46,7 +46,7 @@ def __init__(self, identifier: str) -> None: # ASTBaseBase already implements this method, # but specialising it here improves performance def __eq__(self, other: object) -> bool: - if type(other) is not ASTIdentifier: + if not isinstance(other, ASTIdentifier): return NotImplemented return self.identifier == other.identifier @@ -94,6 +94,14 @@ def __init__(self, names: list[ASTIdentifier], rooted: bool) -> None: self.names = names self.rooted = rooted + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNestedName): + return NotImplemented + return self.names == other.names and self.rooted == other.rooted + + def __hash__(self) -> int: + return hash((self.names, self.rooted)) + @property def name(self) -> ASTNestedName: return self @@ -186,6 +194,14 @@ class ASTBooleanLiteral(ASTLiteral): def __init__(self, value: bool) -> None: self.value = value + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBooleanLiteral): + return NotImplemented + return self.value == other.value + + def __hash__(self) -> int: + return hash(self.value) + def _stringify(self, transform: StringifyTransform) -> str: if self.value: return 'true' @@ -202,6 +218,14 @@ class ASTNumberLiteral(ASTLiteral): def __init__(self, data: str) -> None: self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNumberLiteral): + return NotImplemented + return self.data == other.data + + def __hash__(self) -> int: + return hash(self.data) + def _stringify(self, transform: StringifyTransform) -> str: return self.data @@ -221,6 +245,17 @@ def __init__(self, prefix: str, data: str) -> None: else: raise UnsupportedMultiCharacterCharLiteral(decoded) + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTCharLiteral): + return NotImplemented + return ( + self.prefix == other.prefix + and self.value == other.value + ) + + def __hash__(self) -> int: + return hash((self.prefix, self.value)) + def _stringify(self, transform: StringifyTransform) -> str: if self.prefix is None: return "'" + self.data + "'" @@ -237,6 +272,14 @@ class ASTStringLiteral(ASTLiteral): def __init__(self, data: str) -> None: self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTStringLiteral): + return NotImplemented + return self.data == other.data + + def __hash__(self) -> int: + return hash(self.data) + def _stringify(self, transform: StringifyTransform) -> str: return self.data @@ -251,6 +294,14 @@ def __init__(self, name: ASTNestedName) -> None: # note: this class is basically to cast a nested name as an expression self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTIdExpression): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.name) @@ -266,6 +317,14 @@ class ASTParenExpr(ASTExpression): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParenExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return '(' + transform(self.expr) + ')' @@ -290,6 +349,14 @@ class ASTPostfixCallExpr(ASTPostfixOp): def __init__(self, lst: ASTParenExprList | ASTBracedInitList) -> None: self.lst = lst + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixCallExpr): + return NotImplemented + return self.lst == other.lst + + def __hash__(self) -> int: + return hash(self.lst) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.lst) @@ -302,6 +369,14 @@ class ASTPostfixArray(ASTPostfixOp): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixArray): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return '[' + transform(self.expr) + ']' @@ -334,6 +409,14 @@ class ASTPostfixMemberOfPointer(ASTPostfixOp): def __init__(self, name: ASTNestedName) -> None: self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixMemberOfPointer): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def _stringify(self, transform: StringifyTransform) -> str: return '->' + transform(self.name) @@ -348,6 +431,14 @@ def __init__(self, prefix: ASTExpression, postFixes: list[ASTPostfixOp]) -> None self.prefix = prefix self.postFixes = postFixes + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixExpr): + return NotImplemented + return self.prefix == other.prefix and self.postFixes == other.postFixes + + def __hash__(self) -> int: + return hash((self.prefix, self.postFixes)) + def _stringify(self, transform: StringifyTransform) -> str: return ''.join([transform(self.prefix), *(transform(p) for p in self.postFixes)]) @@ -366,6 +457,14 @@ def __init__(self, op: str, expr: ASTExpression) -> None: self.op = op self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTUnaryOpExpr): + return NotImplemented + return self.op == other.op and self.expr == other.expr + + def __hash__(self) -> int: + return hash((self.op, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: if self.op[0] in 'cn': return self.op + " " + transform(self.expr) @@ -386,6 +485,14 @@ class ASTSizeofType(ASTExpression): def __init__(self, typ: ASTType) -> None: self.typ = typ + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTSizeofType): + return NotImplemented + return self.typ == other.typ + + def __hash__(self) -> int: + return hash(self.typ) + def _stringify(self, transform: StringifyTransform) -> str: return "sizeof(" + transform(self.typ) + ")" @@ -401,6 +508,14 @@ class ASTSizeofExpr(ASTExpression): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTSizeofExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return "sizeof " + transform(self.expr) @@ -415,6 +530,14 @@ class ASTAlignofExpr(ASTExpression): def __init__(self, typ: ASTType) -> None: self.typ = typ + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTAlignofExpr): + return NotImplemented + return self.typ == other.typ + + def __hash__(self) -> int: + return hash(self.typ) + def _stringify(self, transform: StringifyTransform) -> str: return "alignof(" + transform(self.typ) + ")" @@ -434,6 +557,17 @@ def __init__(self, typ: ASTType, expr: ASTExpression) -> None: self.typ = typ self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTCastExpr): + return NotImplemented + return ( + self.typ == other.typ + and self.expr == other.expr + ) + + def __hash__(self) -> int: + return hash((self.typ, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: res = ['('] res.append(transform(self.typ)) @@ -456,6 +590,17 @@ def __init__(self, exprs: list[ASTExpression], ops: list[str]) -> None: self.exprs = exprs self.ops = ops + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBinOpExpr): + return NotImplemented + return ( + self.exprs == other.exprs + and self.ops == other.ops + ) + + def __hash__(self) -> int: + return hash((self.exprs, self.ops)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] res.append(transform(self.exprs[0])) @@ -487,6 +632,17 @@ def __init__(self, exprs: list[ASTExpression], ops: list[str]) -> None: self.exprs = exprs self.ops = ops + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTAssignmentExpr): + return NotImplemented + return ( + self.exprs == other.exprs + and self.ops == other.ops + ) + + def __hash__(self) -> int: + return hash((self.exprs, self.ops)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] res.append(transform(self.exprs[0])) @@ -515,6 +671,14 @@ class ASTFallbackExpr(ASTExpression): def __init__(self, expr: str) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTFallbackExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return self.expr @@ -539,6 +703,14 @@ def __init__(self, names: list[str]) -> None: assert len(names) != 0 self.names = names + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTrailingTypeSpecFundamental): + return NotImplemented + return self.names == other.names + + def __hash__(self) -> int: + return hash(self.names) + def _stringify(self, transform: StringifyTransform) -> str: return ' '.join(self.names) @@ -558,6 +730,17 @@ def __init__(self, prefix: str, nestedName: ASTNestedName) -> None: self.prefix = prefix self.nestedName = nestedName + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTrailingTypeSpecName): + return NotImplemented + return ( + self.prefix == other.prefix + and self.nestedName == other.nestedName + ) + + def __hash__(self) -> int: + return hash((self.prefix, self.nestedName)) + @property def name(self) -> ASTNestedName: return self.nestedName @@ -583,6 +766,14 @@ def __init__(self, arg: ASTTypeWithInit | None, ellipsis: bool = False) -> None: self.arg = arg self.ellipsis = ellipsis + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTFunctionParameter): + return NotImplemented + return self.arg == other.arg and self.ellipsis == other.ellipsis + + def __hash__(self) -> int: + return hash((self.arg, self.ellipsis)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: # the anchor will be our parent return symbol.parent.declaration.get_id(version, prefixed=False) @@ -607,6 +798,14 @@ def __init__(self, args: list[ASTFunctionParameter], attrs: ASTAttributeList) -> self.args = args self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParameters): + return NotImplemented + return self.args == other.args and self.attrs == other.attrs + + def __hash__(self) -> int: + return hash((self.args, self.attrs)) + @property def function_params(self) -> list[ASTFunctionParameter]: return self.args @@ -674,6 +873,30 @@ def __init__(self, storage: str, threadLocal: str, inline: bool, self.const = const self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclSpecsSimple): + return NotImplemented + return ( + self.storage == other.storage + and self.threadLocal == other.threadLocal + and self.inline == other.inline + and self.restrict == other.restrict + and self.volatile == other.volatile + and self.const == other.const + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash(( + self.storage, + self.threadLocal, + self.inline, + self.restrict, + self.volatile, + self.const, + self.attrs, + )) + def mergeWith(self, other: ASTDeclSpecsSimple) -> ASTDeclSpecsSimple: if not other: return self @@ -741,6 +964,24 @@ def __init__(self, outer: str, self.allSpecs = self.leftSpecs.mergeWith(self.rightSpecs) self.trailingTypeSpec = trailing + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclSpecs): + return NotImplemented + return ( + self.outer == other.outer + and self.leftSpecs == other.leftSpecs + and self.rightSpecs == other.rightSpecs + and self.trailingTypeSpec == other.trailingTypeSpec + ) + + def __hash__(self) -> int: + return hash(( + self.outer, + self.leftSpecs, + self.rightSpecs, + self.trailingTypeSpec, + )) + def _stringify(self, transform: StringifyTransform) -> str: res: list[str] = [] l = transform(self.leftSpecs) @@ -796,6 +1037,28 @@ def __init__(self, static: bool, const: bool, volatile: bool, restrict: bool, if size is not None: assert not vla + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTArray): + return NotImplemented + return ( + self.static == other.static + and self.const == other.const + and self.volatile == other.volatile + and self.restrict == other.restrict + and self.vla == other.vla + and self.size == other.size + ) + + def __hash__(self) -> int: + return hash(( + self.static, + self.const, + self.volatile, + self.restrict, + self.vla, + self.size, + )) + def _stringify(self, transform: StringifyTransform) -> str: el = [] if self.static: @@ -861,6 +1124,18 @@ def __init__(self, declId: ASTNestedName, self.arrayOps = arrayOps self.param = param + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorNameParam): + return NotImplemented + return ( + self.declId == other.declId + and self.arrayOps == other.arrayOps + and self.param == other.param + ) + + def __hash__(self) -> int: + return hash((self.declId, self.arrayOps, self.param)) + @property def name(self) -> ASTNestedName: return self.declId @@ -899,6 +1174,14 @@ def __init__(self, declId: ASTNestedName, size: ASTExpression) -> None: self.declId = declId self.size = size + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorNameBitField): + return NotImplemented + return self.declId == other.declId and self.size == other.size + + def __hash__(self) -> int: + return hash((self.declId, self.size)) + @property def name(self) -> ASTNestedName: return self.declId @@ -937,6 +1220,20 @@ def __init__(self, next: ASTDeclarator, restrict: bool, volatile: bool, const: b self.const = const self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorPtr): + return NotImplemented + return ( + self.next == other.next + and self.restrict == other.restrict + and self.volatile == other.volatile + and self.const == other.const + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.next, self.restrict, self.volatile, self.const, self.attrs)) + @property def name(self) -> ASTNestedName: return self.next.name @@ -1006,6 +1303,14 @@ def __init__(self, inner: ASTDeclarator, next: ASTDeclarator) -> None: self.next = next # TODO: we assume the name and params are in inner + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorParen): + return NotImplemented + return self.inner == other.inner and self.next == other.next + + def __hash__(self) -> int: + return hash((self.inner, self.next)) + @property def name(self) -> ASTNestedName: return self.inner.name @@ -1040,6 +1345,14 @@ class ASTParenExprList(ASTBaseParenExprList): def __init__(self, exprs: list[ASTExpression]) -> None: self.exprs = exprs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParenExprList): + return NotImplemented + return self.exprs == other.exprs + + def __hash__(self) -> int: + return hash(self.exprs) + def _stringify(self, transform: StringifyTransform) -> str: exprs = [transform(e) for e in self.exprs] return '(%s)' % ', '.join(exprs) @@ -1064,6 +1377,14 @@ def __init__(self, exprs: list[ASTExpression], trailingComma: bool) -> None: self.exprs = exprs self.trailingComma = trailingComma + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBracedInitList): + return NotImplemented + return self.exprs == other.exprs and self.trailingComma == other.trailingComma + + def __hash__(self) -> int: + return hash((self.exprs, self.trailingComma)) + def _stringify(self, transform: StringifyTransform) -> str: exprs = ', '.join(transform(e) for e in self.exprs) trailingComma = ',' if self.trailingComma else '' @@ -1092,6 +1413,14 @@ def __init__(self, value: ASTBracedInitList | ASTExpression, self.value = value self.hasAssign = hasAssign + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTInitializer): + return NotImplemented + return self.value == other.value and self.hasAssign == other.hasAssign + + def __hash__(self) -> int: + return hash((self.value, self.hasAssign)) + def _stringify(self, transform: StringifyTransform) -> str: val = transform(self.value) if self.hasAssign: @@ -1116,6 +1445,14 @@ def __init__(self, declSpecs: ASTDeclSpecs, decl: ASTDeclarator) -> None: self.declSpecs = declSpecs self.decl = decl + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTType): + return NotImplemented + return self.declSpecs == other.declSpecs and self.decl == other.decl + + def __hash__(self) -> int: + return hash((self.declSpecs, self.decl)) + @property def name(self) -> ASTNestedName: return self.decl.name @@ -1161,6 +1498,14 @@ def __init__(self, type: ASTType, init: ASTInitializer) -> None: self.type = type self.init = init + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTypeWithInit): + return NotImplemented + return self.type == other.type and self.init == other.init + + def __hash__(self) -> int: + return hash((self.type, self.init)) + @property def name(self) -> ASTNestedName: return self.type.name @@ -1190,6 +1535,18 @@ def __init__(self, arg: ASTNestedName | None, ellipsis: bool = False, self.ellipsis = ellipsis self.variadic = variadic + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTMacroParameter): + return NotImplemented + return ( + self.arg == other.arg + and self.ellipsis == other.ellipsis + and self.variadic == other.variadic + ) + + def __hash__(self) -> int: + return hash((self.arg, self.ellipsis, self.variadic)) + def _stringify(self, transform: StringifyTransform) -> str: if self.ellipsis: return '...' @@ -1215,6 +1572,14 @@ def __init__(self, ident: ASTNestedName, args: list[ASTMacroParameter] | None) - self.ident = ident self.args = args + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTMacro): + return NotImplemented + return self.ident == other.ident and self.args == other.args + + def __hash__(self) -> int: + return hash((self.ident, self.args)) + @property def name(self) -> ASTNestedName: return self.ident @@ -1254,6 +1619,14 @@ class ASTStruct(ASTBase): def __init__(self, name: ASTNestedName) -> None: self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTStruct): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: return symbol.get_full_nested_name().get_id(version) @@ -1270,6 +1643,14 @@ class ASTUnion(ASTBase): def __init__(self, name: ASTNestedName) -> None: self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTUnion): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: return symbol.get_full_nested_name().get_id(version) @@ -1286,6 +1667,14 @@ class ASTEnum(ASTBase): def __init__(self, name: ASTNestedName) -> None: self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTEnum): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: return symbol.get_full_nested_name().get_id(version) @@ -1305,6 +1694,18 @@ def __init__(self, name: ASTNestedName, init: ASTInitializer | None, self.init = init self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTEnumerator): + return NotImplemented + return ( + self.name == other.name + and self.init == other.init + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.name, self.init, self.attrs)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: return symbol.get_full_nested_name().get_id(version) @@ -1346,6 +1747,18 @@ def __init__(self, objectType: str, directiveType: str | None, # further changes will be made to this object self._newest_id_cache: str | None = None + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaration): + return NotImplemented + return ( + self.objectType == other.objectType + and self.directiveType == other.directiveType + and self.declaration == other.declaration + and self.semicolon == other.semicolon + and self.symbol == other.symbol + and self.enumeratorScopedSymbol == other.enumeratorScopedSymbol + ) + def clone(self) -> ASTDeclaration: return ASTDeclaration(self.objectType, self.directiveType, self.declaration.clone(), self.semicolon) diff --git a/sphinx/domains/cpp/_ast.py b/sphinx/domains/cpp/_ast.py index ad57695d12f..579e330ebbd 100644 --- a/sphinx/domains/cpp/_ast.py +++ b/sphinx/domains/cpp/_ast.py @@ -52,10 +52,13 @@ def __init__(self, identifier: str) -> None: # ASTBaseBase already implements this method, # but specialising it here improves performance def __eq__(self, other: object) -> bool: - if type(other) is not ASTIdentifier: + if not isinstance(other, ASTIdentifier): return NotImplemented return self.identifier == other.identifier + def __hash__(self) -> int: + return hash(self.identifier) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.identifier) @@ -137,6 +140,14 @@ def __init__(self, identOrOp: ASTIdentifier | ASTOperator, self.identOrOp = identOrOp self.templateArgs = templateArgs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNestedNameElement): + return NotImplemented + return self.identOrOp == other.identOrOp and self.templateArgs == other.templateArgs + + def __hash__(self) -> int: + return hash((self.identOrOp, self.templateArgs)) + def is_operator(self) -> bool: return False @@ -169,6 +180,18 @@ def __init__(self, names: list[ASTNestedNameElement], assert len(self.names) == len(self.templates) self.rooted = rooted + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNestedName): + return NotImplemented + return ( + self.names == other.names + and self.templates == other.templates + and self.rooted == other.rooted + ) + + def __hash__(self) -> int: + return hash((self.names, self.templates, self.rooted)) + @property def name(self) -> ASTNestedName: return self @@ -316,6 +339,12 @@ class ASTLiteral(ASTExpression): class ASTPointerLiteral(ASTLiteral): + def __eq__(self, other: object) -> bool: + return isinstance(other, ASTPointerLiteral) + + def __hash__(self) -> int: + return hash('nullptr') + def _stringify(self, transform: StringifyTransform) -> str: return 'nullptr' @@ -331,6 +360,14 @@ class ASTBooleanLiteral(ASTLiteral): def __init__(self, value: bool) -> None: self.value = value + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBooleanLiteral): + return NotImplemented + return self.value == other.value + + def __hash__(self) -> int: + return hash(self.value) + def _stringify(self, transform: StringifyTransform) -> str: if self.value: return 'true' @@ -352,6 +389,14 @@ class ASTNumberLiteral(ASTLiteral): def __init__(self, data: str) -> None: self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNumberLiteral): + return NotImplemented + return self.data == other.data + + def __hash__(self) -> int: + return hash(self.data) + def _stringify(self, transform: StringifyTransform) -> str: return self.data @@ -368,6 +413,14 @@ class ASTStringLiteral(ASTLiteral): def __init__(self, data: str) -> None: self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTStringLiteral): + return NotImplemented + return self.data == other.data + + def __hash__(self) -> int: + return hash(self.data) + def _stringify(self, transform: StringifyTransform) -> str: return self.data @@ -392,6 +445,17 @@ def __init__(self, prefix: str, data: str) -> None: else: raise UnsupportedMultiCharacterCharLiteral(decoded) + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTCharLiteral): + return NotImplemented + return ( + self.prefix == other.prefix + and self.value == other.value + ) + + def __hash__(self) -> int: + return hash((self.prefix, self.value)) + def _stringify(self, transform: StringifyTransform) -> str: if self.prefix is None: return "'" + self.data + "'" @@ -415,6 +479,14 @@ def __init__(self, literal: ASTLiteral, ident: ASTIdentifier) -> None: self.literal = literal self.ident = ident + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTUserDefinedLiteral): + return NotImplemented + return self.literal == other.literal and self.ident == other.ident + + def __hash__(self) -> int: + return hash((self.literal, self.ident)) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.literal) + transform(self.ident) @@ -431,6 +503,12 @@ def describe_signature(self, signode: TextElement, mode: str, ################################################################################ class ASTThisLiteral(ASTExpression): + def __eq__(self, other: object) -> bool: + return isinstance(other, ASTThisLiteral) + + def __hash__(self) -> int: + return hash("this") + def _stringify(self, transform: StringifyTransform) -> str: return "this" @@ -450,6 +528,18 @@ def __init__(self, leftExpr: ASTExpression | None, self.op = op self.rightExpr = rightExpr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTFoldExpr): + return NotImplemented + return ( + self.leftExpr == other.leftExpr + and self.op == other.op + and self.rightExpr == other.rightExpr + ) + + def __hash__(self) -> int: + return hash((self.leftExpr, self.op, self.rightExpr)) + def _stringify(self, transform: StringifyTransform) -> str: res = ['('] if self.leftExpr: @@ -508,6 +598,14 @@ class ASTParenExpr(ASTExpression): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParenExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return '(' + transform(self.expr) + ')' @@ -526,6 +624,14 @@ def __init__(self, name: ASTNestedName) -> None: # note: this class is basically to cast a nested name as an expression self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTIdExpression): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.name) @@ -553,6 +659,14 @@ class ASTPostfixArray(ASTPostfixOp): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixArray): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return '[' + transform(self.expr) + ']' @@ -570,6 +684,14 @@ class ASTPostfixMember(ASTPostfixOp): def __init__(self, name: ASTNestedName) -> None: self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixMember): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def _stringify(self, transform: StringifyTransform) -> str: return '.' + transform(self.name) @@ -586,6 +708,14 @@ class ASTPostfixMemberOfPointer(ASTPostfixOp): def __init__(self, name: ASTNestedName) -> None: self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixMemberOfPointer): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def _stringify(self, transform: StringifyTransform) -> str: return '->' + transform(self.name) @@ -599,6 +729,12 @@ def describe_signature(self, signode: TextElement, mode: str, class ASTPostfixInc(ASTPostfixOp): + def __eq__(self, other: object) -> bool: + return isinstance(other, ASTPostfixInc) + + def __hash__(self) -> int: + return hash('++') + def _stringify(self, transform: StringifyTransform) -> str: return '++' @@ -611,6 +747,12 @@ def describe_signature(self, signode: TextElement, mode: str, class ASTPostfixDec(ASTPostfixOp): + def __eq__(self, other: object) -> bool: + return isinstance(other, ASTPostfixDec) + + def __hash__(self) -> int: + return hash('--') + def _stringify(self, transform: StringifyTransform) -> str: return '--' @@ -626,6 +768,14 @@ class ASTPostfixCallExpr(ASTPostfixOp): def __init__(self, lst: ASTParenExprList | ASTBracedInitList) -> None: self.lst = lst + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixCallExpr): + return NotImplemented + return self.lst == other.lst + + def __hash__(self) -> int: + return hash(self.lst) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.lst) @@ -647,6 +797,14 @@ def __init__(self, prefix: ASTType, postFixes: list[ASTPostfixOp]) -> None: self.prefix = prefix self.postFixes = postFixes + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixExpr): + return NotImplemented + return self.prefix == other.prefix and self.postFixes == other.postFixes + + def __hash__(self) -> int: + return hash((self.prefix, self.postFixes)) + def _stringify(self, transform: StringifyTransform) -> str: return ''.join([transform(self.prefix), *(transform(p) for p in self.postFixes)]) @@ -670,6 +828,14 @@ def __init__(self, cast: str, typ: ASTType, expr: ASTExpression) -> None: self.typ = typ self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTExplicitCast): + return NotImplemented + return self.cast == other.cast and self.typ == other.typ and self.expr == other.expr + + def __hash__(self) -> int: + return hash((self.cast, self.typ, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: res = [self.cast] res.append('<') @@ -700,6 +866,14 @@ def __init__(self, typeOrExpr: ASTType | ASTExpression, isType: bool) -> None: self.typeOrExpr = typeOrExpr self.isType = isType + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTypeId): + return NotImplemented + return self.typeOrExpr == other.typeOrExpr and self.isType == other.isType + + def __hash__(self) -> int: + return hash((self.typeOrExpr, self.isType)) + def _stringify(self, transform: StringifyTransform) -> str: return 'typeid(' + transform(self.typeOrExpr) + ')' @@ -723,6 +897,14 @@ def __init__(self, op: str, expr: ASTExpression) -> None: self.op = op self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTUnaryOpExpr): + return NotImplemented + return self.op == other.op and self.expr == other.expr + + def __hash__(self) -> int: + return hash((self.op, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: if self.op[0] in 'cn': return self.op + " " + transform(self.expr) @@ -746,6 +928,14 @@ class ASTSizeofParamPack(ASTExpression): def __init__(self, identifier: ASTIdentifier) -> None: self.identifier = identifier + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTSizeofParamPack): + return NotImplemented + return self.identifier == other.identifier + + def __hash__(self) -> int: + return hash(self.identifier) + def _stringify(self, transform: StringifyTransform) -> str: return "sizeof...(" + transform(self.identifier) + ")" @@ -766,6 +956,14 @@ class ASTSizeofType(ASTExpression): def __init__(self, typ: ASTType) -> None: self.typ = typ + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTSizeofType): + return NotImplemented + return self.typ == other.typ + + def __hash__(self) -> int: + return hash(self.typ) + def _stringify(self, transform: StringifyTransform) -> str: return "sizeof(" + transform(self.typ) + ")" @@ -784,6 +982,14 @@ class ASTSizeofExpr(ASTExpression): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTSizeofExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return "sizeof " + transform(self.expr) @@ -801,6 +1007,14 @@ class ASTAlignofExpr(ASTExpression): def __init__(self, typ: ASTType) -> None: self.typ = typ + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTAlignofExpr): + return NotImplemented + return self.typ == other.typ + + def __hash__(self) -> int: + return hash(self.typ) + def _stringify(self, transform: StringifyTransform) -> str: return "alignof(" + transform(self.typ) + ")" @@ -819,6 +1033,14 @@ class ASTNoexceptExpr(ASTExpression): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNoexceptExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return 'noexcept(' + transform(self.expr) + ')' @@ -841,6 +1063,19 @@ def __init__(self, rooted: bool, isNewTypeId: bool, typ: ASTType, self.typ = typ self.initList = initList + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNewExpr): + return NotImplemented + return ( + self.rooted == other.rooted + and self.isNewTypeId == other.isNewTypeId + and self.typ == other.typ + and self.initList == other.initList + ) + + def __hash__(self) -> int: + return hash((self.rooted, self.isNewTypeId, self.typ, self.initList)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] if self.rooted: @@ -888,6 +1123,18 @@ def __init__(self, rooted: bool, array: bool, expr: ASTExpression) -> None: self.array = array self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeleteExpr): + return NotImplemented + return ( + self.rooted == other.rooted + and self.array == other.array + and self.expr == other.expr + ) + + def __hash__(self) -> int: + return hash((self.rooted, self.array, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] if self.rooted: @@ -925,6 +1172,17 @@ def __init__(self, typ: ASTType, expr: ASTExpression) -> None: self.typ = typ self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTCastExpr): + return NotImplemented + return ( + self.typ == other.typ + and self.expr == other.expr + ) + + def __hash__(self) -> int: + return hash((self.typ, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: res = ['('] res.append(transform(self.typ)) @@ -950,6 +1208,17 @@ def __init__(self, exprs: list[ASTExpression], ops: list[str]) -> None: self.exprs = exprs self.ops = ops + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBinOpExpr): + return NotImplemented + return ( + self.exprs == other.exprs + and self.ops == other.ops + ) + + def __hash__(self) -> int: + return hash((self.exprs, self.ops)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] res.append(transform(self.exprs[0])) @@ -990,6 +1259,18 @@ def __init__(self, ifExpr: ASTExpression, thenExpr: ASTExpression, self.thenExpr = thenExpr self.elseExpr = elseExpr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTConditionalExpr): + return NotImplemented + return ( + self.ifExpr == other.ifExpr + and self.thenExpr == other.thenExpr + and self.elseExpr == other.elseExpr + ) + + def __hash__(self) -> int: + return hash((self.ifExpr, self.thenExpr, self.elseExpr)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] res.append(transform(self.ifExpr)) @@ -1027,6 +1308,14 @@ def __init__(self, exprs: list[ASTExpression | ASTBracedInitList], self.exprs = exprs self.trailingComma = trailingComma + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBracedInitList): + return NotImplemented + return self.exprs == other.exprs and self.trailingComma == other.trailingComma + + def __hash__(self) -> int: + return hash((self.exprs, self.trailingComma)) + def get_id(self, version: int) -> str: return "il%sE" % ''.join(e.get_id(version) for e in self.exprs) @@ -1059,6 +1348,18 @@ def __init__(self, leftExpr: ASTExpression, op: str, self.op = op self.rightExpr = rightExpr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTAssignmentExpr): + return NotImplemented + return ( + self.leftExpr == other.leftExpr + and self.op == other.op + and self.rightExpr == other.rightExpr + ) + + def __hash__(self) -> int: + return hash((self.leftExpr, self.op, self.rightExpr)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] res.append(transform(self.leftExpr)) @@ -1093,6 +1394,14 @@ def __init__(self, exprs: list[ASTExpression]) -> None: assert len(exprs) > 0 self.exprs = exprs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTCommaExpr): + return NotImplemented + return self.exprs == other.exprs + + def __hash__(self) -> int: + return hash(self.exprs) + def _stringify(self, transform: StringifyTransform) -> str: return ', '.join(transform(e) for e in self.exprs) @@ -1118,6 +1427,14 @@ class ASTFallbackExpr(ASTExpression): def __init__(self, expr: str) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTFallbackExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return self.expr @@ -1140,6 +1457,9 @@ class ASTOperator(ASTBase): def __eq__(self, other: object) -> bool: raise NotImplementedError(repr(self)) + def __hash__(self) -> int: + raise NotImplementedError(repr(self)) + def is_anon(self) -> bool: return False @@ -1193,6 +1513,9 @@ def __eq__(self, other: object) -> bool: return NotImplemented return self.op == other.op + def __hash__(self) -> int: + return hash(self.op) + def get_id(self, version: int) -> str: if version == 1: ids = _id_operator_v1 @@ -1228,6 +1551,9 @@ def __eq__(self, other: object) -> bool: return NotImplemented return self.identifier == other.identifier + def __hash__(self) -> int: + return hash(self.identifier) + def get_id(self, version: int) -> str: if version == 1: raise NoOldIdError @@ -1252,6 +1578,9 @@ def __eq__(self, other: object) -> bool: return NotImplemented return self.type == other.type + def __hash__(self) -> int: + return hash(self.type) + def get_id(self, version: int) -> str: if version == 1: return 'castto-%s-operator' % self.type.get_id(version) @@ -1275,6 +1604,14 @@ class ASTTemplateArgConstant(ASTBase): def __init__(self, value: ASTExpression) -> None: self.value = value + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateArgConstant): + return NotImplemented + return self.value == other.value + + def __hash__(self) -> int: + return hash(self.value) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.value) @@ -1298,6 +1635,14 @@ def __init__(self, args: list[ASTType | ASTTemplateArgConstant], self.args = args self.packExpansion = packExpansion + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateArgs): + return NotImplemented + return self.args == other.args and self.packExpansion == other.packExpansion + + def __hash__(self) -> int: + return hash((self.args, self.packExpansion)) + def get_id(self, version: int) -> str: if version == 1: res = [] @@ -1361,6 +1706,14 @@ def __init__(self, names: list[str], canonNames: list[str]) -> None: # the canonical name list is for ID lookup self.canonNames = canonNames + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTrailingTypeSpecFundamental): + return NotImplemented + return self.names == other.names and self.canonNames == other.canonNames + + def __hash__(self) -> int: + return hash((self.names, self.canonNames)) + def _stringify(self, transform: StringifyTransform) -> str: return ' '.join(self.names) @@ -1394,6 +1747,12 @@ def describe_signature(self, signode: TextElement, mode: str, class ASTTrailingTypeSpecDecltypeAuto(ASTTrailingTypeSpec): + def __eq__(self, other: object) -> bool: + return isinstance(other, ASTTrailingTypeSpecDecltypeAuto) + + def __hash__(self) -> int: + return hash('decltype(auto)') + def _stringify(self, transform: StringifyTransform) -> str: return 'decltype(auto)' @@ -1414,6 +1773,14 @@ class ASTTrailingTypeSpecDecltype(ASTTrailingTypeSpec): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTrailingTypeSpecDecltype): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return 'decltype(' + transform(self.expr) + ')' @@ -1437,6 +1804,18 @@ def __init__(self, prefix: str, nestedName: ASTNestedName, self.nestedName = nestedName self.placeholderType = placeholderType + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTrailingTypeSpecName): + return NotImplemented + return ( + self.prefix == other.prefix + and self.nestedName == other.nestedName + and self.placeholderType == other.placeholderType + ) + + def __hash__(self) -> int: + return hash((self.prefix, self.nestedName, self.placeholderType)) + @property def name(self) -> ASTNestedName: return self.nestedName @@ -1480,6 +1859,14 @@ def __init__(self, arg: ASTTypeWithInit | ASTTemplateParamConstrainedTypeWithIni self.arg = arg self.ellipsis = ellipsis + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTFunctionParameter): + return NotImplemented + return self.arg == other.arg and self.ellipsis == other.ellipsis + + def __hash__(self) -> int: + return hash((self.arg, self.ellipsis)) + def get_id( self, version: int, objectType: str | None = None, symbol: Symbol | None = None, ) -> str: @@ -1512,6 +1899,14 @@ class ASTNoexceptSpec(ASTBase): def __init__(self, expr: ASTExpression | None) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNoexceptSpec): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: if self.expr: return 'noexcept(' + transform(self.expr) + ')' @@ -1543,6 +1938,28 @@ def __init__(self, args: list[ASTFunctionParameter], volatile: bool, const: bool self.attrs = attrs self.initializer = initializer + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParametersQualifiers): + return NotImplemented + return ( + self.args == other.args + and self.volatile == other.volatile + and self.const == other.const + and self.refQual == other.refQual + and self.exceptionSpec == other.exceptionSpec + and self.trailingReturn == other.trailingReturn + and self.override == other.override + and self.final == other.final + and self.attrs == other.attrs + and self.initializer == other.initializer + ) + + def __hash__(self) -> int: + return hash(( + self.args, self.volatile, self.const, self.refQual, self.exceptionSpec, + self.trailingReturn, self.override, self.final, self.attrs, self.initializer + )) + @property def function_params(self) -> list[ASTFunctionParameter]: return self.args @@ -1681,6 +2098,14 @@ class ASTExplicitSpec(ASTBase): def __init__(self, expr: ASTExpression | None) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTExplicitSpec): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: res = ['explicit'] if self.expr is not None: @@ -1717,6 +2142,40 @@ def __init__(self, storage: str, threadLocal: bool, inline: bool, virtual: bool, self.friend = friend self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclSpecsSimple): + return NotImplemented + return ( + self.storage == other.storage + and self.threadLocal == other.threadLocal + and self.inline == other.inline + and self.virtual == other.virtual + and self.explicitSpec == other.explicitSpec + and self.consteval == other.consteval + and self.constexpr == other.constexpr + and self.constinit == other.constinit + and self.volatile == other.volatile + and self.const == other.const + and self.friend == other.friend + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash(( + self.storage, + self.threadLocal, + self.inline, + self.virtual, + self.explicitSpec, + self.consteval, + self.constexpr, + self.constinit, + self.volatile, + self.const, + self.friend, + self.attrs, + )) + def mergeWith(self, other: ASTDeclSpecsSimple) -> ASTDeclSpecsSimple: if not other: return self @@ -1811,6 +2270,24 @@ def __init__(self, outer: str, self.allSpecs = self.leftSpecs.mergeWith(self.rightSpecs) self.trailingTypeSpec = trailing + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclSpecs): + return NotImplemented + return ( + self.outer == other.outer + and self.leftSpecs == other.leftSpecs + and self.rightSpecs == other.rightSpecs + and self.trailingTypeSpec == other.trailingTypeSpec + ) + + def __hash__(self) -> int: + return hash(( + self.outer, + self.leftSpecs, + self.rightSpecs, + self.trailingTypeSpec, + )) + def get_id(self, version: int) -> str: if version == 1: res = [] @@ -1873,6 +2350,14 @@ class ASTArray(ASTBase): def __init__(self, size: ASTExpression) -> None: self.size = size + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTArray): + return NotImplemented + return self.size == other.size + + def __hash__(self) -> int: + return hash(self.size) + def _stringify(self, transform: StringifyTransform) -> str: if self.size: return '[' + transform(self.size) + ']' @@ -1953,6 +2438,18 @@ def __init__(self, declId: ASTNestedName, self.arrayOps = arrayOps self.paramQual = paramQual + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorNameParamQual): + return NotImplemented + return ( + self.declId == other.declId + and self.arrayOps == other.arrayOps + and self.paramQual == other.paramQual + ) + + def __hash__(self) -> int: + return hash((self.declId, self.arrayOps, self.paramQual)) + @property def name(self) -> ASTNestedName: return self.declId @@ -2037,6 +2534,14 @@ def __init__(self, declId: ASTNestedName, size: ASTExpression) -> None: self.declId = declId self.size = size + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorNameBitField): + return NotImplemented + return self.declId == other.declId and self.size == other.size + + def __hash__(self) -> int: + return hash((self.declId, self.size)) + @property def name(self) -> ASTNestedName: return self.declId @@ -2087,6 +2592,19 @@ def __init__(self, next: ASTDeclarator, volatile: bool, const: bool, self.const = const self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorPtr): + return NotImplemented + return ( + self.next == other.next + and self.volatile == other.volatile + and self.const == other.const + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.next, self.volatile, self.const, self.attrs)) + @property def name(self) -> ASTNestedName: return self.next.name @@ -2192,6 +2710,14 @@ def __init__(self, next: ASTDeclarator, attrs: ASTAttributeList) -> None: self.next = next self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorRef): + return NotImplemented + return self.next == other.next and self.attrs == other.attrs + + def __hash__(self) -> int: + return hash((self.next, self.attrs)) + @property def name(self) -> ASTNestedName: return self.next.name @@ -2258,6 +2784,14 @@ def __init__(self, next: ASTDeclarator) -> None: assert next self.next = next + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorParamPack): + return NotImplemented + return self.next == other.next + + def __hash__(self) -> int: + return hash(self.next) + @property def name(self) -> ASTNestedName: return self.next.name @@ -2326,6 +2860,19 @@ def __init__(self, className: ASTNestedName, self.volatile = volatile self.next = next + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorMemPtr): + return NotImplemented + return ( + self.className == other.className + and self.const == other.const + and self.volatile == other.volatile + and self.next == other.next + ) + + def __hash__(self) -> int: + return hash((self.className, self.const, self.volatile, self.next)) + @property def name(self) -> ASTNestedName: return self.next.name @@ -2424,6 +2971,14 @@ def __init__(self, inner: ASTDeclarator, next: ASTDeclarator) -> None: self.next = next # TODO: we assume the name, params, and qualifiers are in inner + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorParen): + return NotImplemented + return self.inner == other.inner and self.next == other.next + + def __hash__(self) -> int: + return hash((self.inner, self.next)) + @property def name(self) -> ASTNestedName: return self.inner.name @@ -2493,6 +3048,14 @@ class ASTPackExpansionExpr(ASTExpression): def __init__(self, expr: ASTExpression | ASTBracedInitList) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPackExpansionExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.expr) + '...' @@ -2510,6 +3073,14 @@ class ASTParenExprList(ASTBaseParenExprList): def __init__(self, exprs: list[ASTExpression | ASTBracedInitList]) -> None: self.exprs = exprs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParenExprList): + return NotImplemented + return self.exprs == other.exprs + + def __hash__(self) -> int: + return hash(self.exprs) + def get_id(self, version: int) -> str: return "pi%sE" % ''.join(e.get_id(version) for e in self.exprs) @@ -2538,6 +3109,14 @@ def __init__(self, value: ASTExpression | ASTBracedInitList, self.value = value self.hasAssign = hasAssign + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTInitializer): + return NotImplemented + return self.value == other.value and self.hasAssign == other.hasAssign + + def __hash__(self) -> int: + return hash((self.value, self.hasAssign)) + def _stringify(self, transform: StringifyTransform) -> str: val = transform(self.value) if self.hasAssign: @@ -2562,6 +3141,14 @@ def __init__(self, declSpecs: ASTDeclSpecs, decl: ASTDeclarator) -> None: self.declSpecs = declSpecs self.decl = decl + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTType): + return NotImplemented + return self.declSpecs == other.declSpecs and self.decl == other.decl + + def __hash__(self) -> int: + return hash((self.declSpecs, self.decl)) + @property def name(self) -> ASTNestedName: return self.decl.name @@ -2671,6 +3258,14 @@ def __init__(self, type: ASTType, init: ASTType) -> None: self.type = type self.init = init + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateParamConstrainedTypeWithInit): + return NotImplemented + return self.type == other.type and self.init == other.init + + def __hash__(self) -> int: + return hash((self.type, self.init)) + @property def name(self) -> ASTNestedName: return self.type.name @@ -2712,6 +3307,14 @@ def __init__(self, type: ASTType, init: ASTInitializer) -> None: self.type = type self.init = init + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTypeWithInit): + return NotImplemented + return self.type == other.type and self.init == other.init + + def __hash__(self) -> int: + return hash((self.type, self.init)) + @property def name(self) -> ASTNestedName: return self.type.name @@ -2749,6 +3352,14 @@ def __init__(self, name: ASTNestedName, type: ASTType | None) -> None: self.name = name self.type = type + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTypeUsing): + return NotImplemented + return self.name == other.name and self.type == other.type + + def __hash__(self) -> int: + return hash((self.name, self.type)) + def get_id(self, version: int, objectType: str | None = None, symbol: Symbol | None = None) -> str: if version == 1: @@ -2785,6 +3396,14 @@ def __init__(self, nestedName: ASTNestedName, initializer: ASTInitializer) -> No self.nestedName = nestedName self.initializer = initializer + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTConcept): + return NotImplemented + return self.nestedName == other.nestedName and self.initializer == other.initializer + + def __hash__(self) -> int: + return hash((self.nestedName, self.initializer)) + @property def name(self) -> ASTNestedName: return self.nestedName @@ -2816,6 +3435,19 @@ def __init__(self, name: ASTNestedName, visibility: str, self.virtual = virtual self.pack = pack + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBaseClass): + return NotImplemented + return ( + self.name == other.name + and self.visibility == other.visibility + and self.virtual == other.virtual + and self.pack == other.pack + ) + + def __hash__(self) -> int: + return hash((self.name, self.visibility, self.virtual, self.pack)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] if self.visibility is not None: @@ -2851,6 +3483,19 @@ def __init__(self, name: ASTNestedName, final: bool, bases: list[ASTBaseClass], self.bases = bases self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTClass): + return NotImplemented + return ( + self.name == other.name + and self.final == other.final + and self.bases == other.bases + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.name, self.final, self.bases, self.attrs)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: return symbol.get_full_nested_name().get_id(version) @@ -2899,6 +3544,14 @@ def __init__(self, name: ASTNestedName, attrs: ASTAttributeList) -> None: self.name = name self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTUnion): + return NotImplemented + return self.name == other.name and self.attrs == other.attrs + + def __hash__(self) -> int: + return hash((self.name, self.attrs)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: if version == 1: raise NoOldIdError @@ -2929,6 +3582,19 @@ def __init__(self, name: ASTNestedName, scoped: str, underlyingType: ASTType, self.underlyingType = underlyingType self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTEnum): + return NotImplemented + return ( + self.name == other.name + and self.scoped == other.scoped + and self.underlyingType == other.underlyingType + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.name, self.scoped, self.underlyingType, self.attrs)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: if version == 1: raise NoOldIdError @@ -2971,6 +3637,18 @@ def __init__(self, name: ASTNestedName, init: ASTInitializer | None, self.init = init self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTEnumerator): + return NotImplemented + return ( + self.name == other.name + and self.init == other.init + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.name, self.init, self.attrs)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: if version == 1: raise NoOldIdError @@ -3035,6 +3713,19 @@ def __init__(self, key: str, identifier: ASTIdentifier, self.parameterPack = parameterPack self.default = default + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateKeyParamPackIdDefault): + return NotImplemented + return ( + self.key == other.key + and self.identifier == other.identifier + and self.parameterPack == other.parameterPack + and self.default == other.default + ) + + def __hash__(self) -> int: + return hash((self.key, self.identifier, self.parameterPack, self.default)) + def get_identifier(self) -> ASTIdentifier: return self.identifier @@ -3086,6 +3777,14 @@ def __init__(self, data: ASTTemplateKeyParamPackIdDefault) -> None: assert data self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateParamType): + return NotImplemented + return self.data == other.data + + def __hash__(self) -> int: + return hash(self.data) + @property def name(self) -> ASTNestedName: id = self.get_identifier() @@ -3125,6 +3824,17 @@ def __init__(self, nestedParams: ASTTemplateParams, self.nestedParams = nestedParams self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateParamTemplateType): + return NotImplemented + return ( + self.nestedParams == other.nestedParams + and self.data == other.data + ) + + def __hash__(self) -> int: + return hash((self.nestedParams, self.data)) + @property def name(self) -> ASTNestedName: id = self.get_identifier() @@ -3166,6 +3876,14 @@ def __init__(self, self.param = param self.parameterPack = parameterPack + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateParamNonType): + return NotImplemented + return ( + self.param == other.param + and self.parameterPack == other.parameterPack + ) + @property def name(self) -> ASTNestedName: id = self.get_identifier() @@ -3221,6 +3939,14 @@ def __init__(self, params: list[ASTTemplateParam], self.params = params self.requiresClause = requiresClause + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateParams): + return NotImplemented + return self.params == other.params and self.requiresClause == other.requiresClause + + def __hash__(self) -> int: + return hash((self.params, self.requiresClause)) + def get_id(self, version: int, excludeRequires: bool = False) -> str: assert version >= 2 res = [] @@ -3295,6 +4021,17 @@ def __init__(self, identifier: ASTIdentifier, parameterPack: bool) -> None: self.identifier = identifier self.parameterPack = parameterPack + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateIntroductionParameter): + return NotImplemented + return ( + self.identifier == other.identifier + and self.parameterPack == other.parameterPack + ) + + def __hash__(self) -> int: + return hash((self.identifier, self.parameterPack)) + @property def name(self) -> ASTNestedName: id = self.get_identifier() @@ -3351,6 +4088,14 @@ def __init__(self, concept: ASTNestedName, self.concept = concept self.params = params + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateIntroduction): + return NotImplemented + return self.concept == other.concept and self.params == other.params + + def __hash__(self) -> int: + return hash((self.concept, self.params)) + def get_id(self, version: int) -> str: assert version >= 2 return ''.join([ @@ -3402,6 +4147,14 @@ def __init__(self, # templates is None means it's an explicit instantiation of a variable self.templates = templates + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateDeclarationPrefix): + return NotImplemented + return self.templates == other.templates + + def __hash__(self) -> int: + return hash(self.templates) + def get_requires_clause_in_last(self) -> ASTRequiresClause | None: if self.templates is None: return None @@ -3436,6 +4189,14 @@ class ASTRequiresClause(ASTBase): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTRequiresClause): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return 'requires ' + transform(self.expr) @@ -3472,6 +4233,21 @@ def __init__(self, objectType: str, directiveType: str | None = None, # further changes will be made to this object self._newest_id_cache: str | None = None + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaration): + return NotImplemented + return ( + self.objectType == other.objectType + and self.directiveType == other.directiveType + and self.visibility == other.visibility + and self.templatePrefix == other.templatePrefix + and self.declaration == other.declaration + and self.trailingRequiresClause == other.trailingRequiresClause + and self.semicolon == other.semicolon + and self.symbol == other.symbol + and self.enumeratorScopedSymbol == other.enumeratorScopedSymbol + ) + def clone(self) -> ASTDeclaration: templatePrefixClone = self.templatePrefix.clone() if self.templatePrefix else None trailingRequiresClasueClone = self.trailingRequiresClause.clone() \ @@ -3627,6 +4403,14 @@ def __init__(self, nestedName: ASTNestedName, self.nestedName = nestedName self.templatePrefix = templatePrefix + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNamespace): + return NotImplemented + return ( + self.nestedName == other.nestedName + and self.templatePrefix == other.templatePrefix + ) + def _stringify(self, transform: StringifyTransform) -> str: res = [] if self.templatePrefix: diff --git a/sphinx/util/cfamily.py b/sphinx/util/cfamily.py index c8879839ef0..65402d8aa1d 100644 --- a/sphinx/util/cfamily.py +++ b/sphinx/util/cfamily.py @@ -99,9 +99,6 @@ def __eq__(self, other: object) -> bool: return False return True - # Defining __hash__ = None is not strictly needed when __eq__ is defined. - __hash__ = None # type: ignore[assignment] - def clone(self) -> Any: return deepcopy(self) @@ -131,6 +128,14 @@ class ASTCPPAttribute(ASTAttribute): def __init__(self, arg: str) -> None: self.arg = arg + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTCPPAttribute): + return NotImplemented + return self.arg == other.arg + + def __hash__(self) -> int: + return hash(self.arg) + def _stringify(self, transform: StringifyTransform) -> str: return "[[" + self.arg + "]]" @@ -146,10 +151,13 @@ def __init__(self, name: str, args: ASTBaseParenExprList | None) -> None: self.args = args def __eq__(self, other: object) -> bool: - if type(other) is not ASTGnuAttribute: + if not isinstance(other, ASTGnuAttribute): return NotImplemented return self.name == other.name and self.args == other.args + def __hash__(self) -> int: + return hash((self.name, self.args)) + def _stringify(self, transform: StringifyTransform) -> str: res = [self.name] if self.args: @@ -161,6 +169,14 @@ class ASTGnuAttributeList(ASTAttribute): def __init__(self, attrs: list[ASTGnuAttribute]) -> None: self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTGnuAttributeList): + return NotImplemented + return self.attrs == other.attrs + + def __hash__(self) -> int: + return hash(self.attrs) + def _stringify(self, transform: StringifyTransform) -> str: res = ['__attribute__(('] first = True @@ -183,6 +199,14 @@ class ASTIdAttribute(ASTAttribute): def __init__(self, id: str) -> None: self.id = id + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTIdAttribute): + return NotImplemented + return self.id == other.id + + def __hash__(self) -> int: + return hash(self.id) + def _stringify(self, transform: StringifyTransform) -> str: return self.id @@ -197,6 +221,14 @@ def __init__(self, id: str, arg: str) -> None: self.id = id self.arg = arg + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParenAttribute): + return NotImplemented + return self.id == other.id and self.arg == other.arg + + def __hash__(self) -> int: + return hash((self.id, self.arg)) + def _stringify(self, transform: StringifyTransform) -> str: return self.id + '(' + self.arg + ')' @@ -210,10 +242,13 @@ def __init__(self, attrs: list[ASTAttribute]) -> None: self.attrs = attrs def __eq__(self, other: object) -> bool: - if type(other) is not ASTAttributeList: + if not isinstance(other, ASTAttributeList): return NotImplemented return self.attrs == other.attrs + def __hash__(self) -> int: + return hash(self.attrs) + def __len__(self) -> int: return len(self.attrs) From 31b938db87833a56fe4d94077c0727de39522967 Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Tue, 23 Apr 2024 04:28:27 +0100 Subject: [PATCH 2/6] Improve ``ASTBaseBase.__eq__`` --- sphinx/util/cfamily.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sphinx/util/cfamily.py b/sphinx/util/cfamily.py index 65402d8aa1d..f7800405562 100644 --- a/sphinx/util/cfamily.py +++ b/sphinx/util/cfamily.py @@ -90,14 +90,11 @@ class NoOldIdError(Exception): class ASTBaseBase: def __eq__(self, other: object) -> bool: if type(self) is not type(other): - return False + return NotImplemented try: - for key, value in self.__dict__.items(): - if value != getattr(other, key): - return False + return self.__dict__ == other.__dict__ except AttributeError: return False - return True def clone(self) -> Any: return deepcopy(self) From afc6f2ac8c78658196e63c221bc14b6e757f2428 Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Tue, 23 Apr 2024 04:32:10 +0100 Subject: [PATCH 3/6] is_anon() -> is_anonymous --- sphinx/domains/c/_ast.py | 7 ++++--- sphinx/domains/cpp/_ast.py | 19 +++++++++++-------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/sphinx/domains/c/_ast.py b/sphinx/domains/c/_ast.py index 2ff3c270f19..5ffd52cd7bb 100644 --- a/sphinx/domains/c/_ast.py +++ b/sphinx/domains/c/_ast.py @@ -42,6 +42,7 @@ def __init__(self, identifier: str) -> None: assert identifier is not None assert len(identifier) != 0 self.identifier = identifier + self.is_anonymous = identifier[0] == '@' # ASTBaseBase already implements this method, # but specialising it here improves performance @@ -51,7 +52,7 @@ def __eq__(self, other: object) -> bool: return self.identifier == other.identifier def is_anon(self) -> bool: - return self.identifier[0] == '@' + return self.is_anonymous # and this is where we finally make a difference between __str__ and the display string @@ -59,13 +60,13 @@ def __str__(self) -> str: return self.identifier def get_display_string(self) -> str: - return "[anonymous]" if self.is_anon() else self.identifier + return "[anonymous]" if self.is_anonymous else self.identifier def describe_signature(self, signode: TextElement, mode: str, env: BuildEnvironment, prefix: str, symbol: Symbol) -> None: # note: slightly different signature of describe_signature due to the prefix verify_description_mode(mode) - if self.is_anon(): + if self.is_anonymous: node = addnodes.desc_sig_name(text="[anonymous]") else: node = addnodes.desc_sig_name(self.identifier, self.identifier) diff --git a/sphinx/domains/cpp/_ast.py b/sphinx/domains/cpp/_ast.py index 579e330ebbd..8ec77b0bd98 100644 --- a/sphinx/domains/cpp/_ast.py +++ b/sphinx/domains/cpp/_ast.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar, Literal from docutils import nodes @@ -48,6 +48,7 @@ def __init__(self, identifier: str) -> None: assert identifier is not None assert len(identifier) != 0 self.identifier = identifier + self.is_anonymous = identifier[0] == '@' # ASTBaseBase already implements this method, # but specialising it here improves performance @@ -63,10 +64,10 @@ def _stringify(self, transform: StringifyTransform) -> str: return transform(self.identifier) def is_anon(self) -> bool: - return self.identifier[0] == '@' + return self.is_anonymous def get_id(self, version: int) -> str: - if self.is_anon() and version < 3: + if self.is_anonymous and version < 3: raise NoOldIdError if version == 1: if self.identifier == 'size_t': @@ -79,7 +80,7 @@ def get_id(self, version: int) -> str: # a destructor, just use an arbitrary version of dtors return 'D0' else: - if self.is_anon(): + if self.is_anonymous: return 'Ut%d_%s' % (len(self.identifier) - 1, self.identifier[1:]) else: return str(len(self.identifier)) + self.identifier @@ -90,12 +91,12 @@ def __str__(self) -> str: return self.identifier def get_display_string(self) -> str: - return "[anonymous]" if self.is_anon() else self.identifier + return "[anonymous]" if self.is_anonymous else self.identifier def describe_signature(self, signode: TextElement, mode: str, env: BuildEnvironment, prefix: str, templateArgs: str, symbol: Symbol) -> None: verify_description_mode(mode) - if self.is_anon(): + if self.is_anonymous: node = addnodes.desc_sig_name(text="[anonymous]") else: node = addnodes.desc_sig_name(self.identifier, self.identifier) @@ -121,7 +122,7 @@ def describe_signature(self, signode: TextElement, mode: str, env: BuildEnvironm # the target is 'operator""id' instead of just 'id' assert len(prefix) == 0 assert len(templateArgs) == 0 - assert not self.is_anon() + assert not self.is_anonymous targetText = 'operator""' + self.identifier pnode = addnodes.pending_xref('', refdomain='cpp', reftype='identifier', @@ -1454,6 +1455,8 @@ def describe_signature(self, signode: TextElement, mode: str, ################################################################################ class ASTOperator(ASTBase): + is_anonymous: ClassVar[Literal[False]] = False + def __eq__(self, other: object) -> bool: raise NotImplementedError(repr(self)) @@ -1461,7 +1464,7 @@ def __hash__(self) -> int: raise NotImplementedError(repr(self)) def is_anon(self) -> bool: - return False + return self.is_anonymous def is_operator(self) -> bool: return True From c5e72b3e1bfdcc1ccb3aa6cfe8b143d19d100b9d Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Tue, 23 Apr 2024 04:40:31 +0100 Subject: [PATCH 4/6] Rename ``ASTIdentifier.identifier`` to ``ASTIdentifier.name`` --- sphinx/domains/c/_ast.py | 30 ++++++++++++++++-------- sphinx/domains/cpp/_ast.py | 48 +++++++++++++++++++++++--------------- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/sphinx/domains/c/_ast.py b/sphinx/domains/c/_ast.py index 5ffd52cd7bb..6082a56fead 100644 --- a/sphinx/domains/c/_ast.py +++ b/sphinx/domains/c/_ast.py @@ -1,5 +1,7 @@ from __future__ import annotations +import sys +import warnings from typing import TYPE_CHECKING, Any, Union, cast from docutils import nodes @@ -38,18 +40,18 @@ def describe_signature(self, signode: TextElement, mode: str, ################################################################################ class ASTIdentifier(ASTBaseBase): - def __init__(self, identifier: str) -> None: - assert identifier is not None - assert len(identifier) != 0 - self.identifier = identifier - self.is_anonymous = identifier[0] == '@' + def __init__(self, name: str) -> None: + if not isinstance(name, str) or len(name) == 0: + raise AssertionError + self.name = sys.intern(name) + self.is_anonymous = name[0] == '@' # ASTBaseBase already implements this method, # but specialising it here improves performance def __eq__(self, other: object) -> bool: if not isinstance(other, ASTIdentifier): return NotImplemented - return self.identifier == other.identifier + return self.name == other.name def is_anon(self) -> bool: return self.is_anonymous @@ -57,10 +59,10 @@ def is_anon(self) -> bool: # and this is where we finally make a difference between __str__ and the display string def __str__(self) -> str: - return self.identifier + return self.name def get_display_string(self) -> str: - return "[anonymous]" if self.is_anonymous else self.identifier + return "[anonymous]" if self.is_anonymous else self.name def describe_signature(self, signode: TextElement, mode: str, env: BuildEnvironment, prefix: str, symbol: Symbol) -> None: @@ -69,9 +71,9 @@ def describe_signature(self, signode: TextElement, mode: str, env: BuildEnvironm if self.is_anonymous: node = addnodes.desc_sig_name(text="[anonymous]") else: - node = addnodes.desc_sig_name(self.identifier, self.identifier) + node = addnodes.desc_sig_name(self.name, self.name) if mode == 'markType': - targetText = prefix + self.identifier + targetText = prefix + self.name pnode = addnodes.pending_xref('', refdomain='c', reftype='identifier', reftarget=targetText, modname=None, @@ -88,6 +90,14 @@ def describe_signature(self, signode: TextElement, mode: str, env: BuildEnvironm else: raise Exception('Unknown description mode: %s' % mode) + @property + def identifier(self) -> str: + warnings.warn( + '`ASTIdentifier.identifier` is deprecated, use `ASTIdentifier.name` instead', + DeprecationWarning, stacklevel=2, + ) + return self.name + class ASTNestedName(ASTBase): def __init__(self, names: list[ASTIdentifier], rooted: bool) -> None: diff --git a/sphinx/domains/cpp/_ast.py b/sphinx/domains/cpp/_ast.py index 8ec77b0bd98..141d5112c8e 100644 --- a/sphinx/domains/cpp/_ast.py +++ b/sphinx/domains/cpp/_ast.py @@ -1,5 +1,7 @@ from __future__ import annotations +import sys +import warnings from typing import TYPE_CHECKING, Any, ClassVar, Literal from docutils import nodes @@ -44,24 +46,24 @@ class ASTBase(ASTBaseBase): ################################################################################ class ASTIdentifier(ASTBase): - def __init__(self, identifier: str) -> None: - assert identifier is not None - assert len(identifier) != 0 - self.identifier = identifier - self.is_anonymous = identifier[0] == '@' + def __init__(self, name: str) -> None: + if not isinstance(name, str) or len(name) == 0: + raise AssertionError + self.name = sys.intern(name) + self.is_anonymous = name[0] == '@' # ASTBaseBase already implements this method, # but specialising it here improves performance def __eq__(self, other: object) -> bool: if not isinstance(other, ASTIdentifier): return NotImplemented - return self.identifier == other.identifier + return self.name == other.name def __hash__(self) -> int: - return hash(self.identifier) + return hash(self.name) def _stringify(self, transform: StringifyTransform) -> str: - return transform(self.identifier) + return transform(self.name) def is_anon(self) -> bool: return self.is_anonymous @@ -70,28 +72,28 @@ def get_id(self, version: int) -> str: if self.is_anonymous and version < 3: raise NoOldIdError if version == 1: - if self.identifier == 'size_t': + if self.name == 'size_t': return 's' else: - return self.identifier - if self.identifier == "std": + return self.name + if self.name == "std": return 'St' - elif self.identifier[0] == "~": + elif self.name[0] == "~": # a destructor, just use an arbitrary version of dtors return 'D0' else: if self.is_anonymous: - return 'Ut%d_%s' % (len(self.identifier) - 1, self.identifier[1:]) + return 'Ut%d_%s' % (len(self.name) - 1, self.name[1:]) else: - return str(len(self.identifier)) + self.identifier + return str(len(self.name)) + self.name # and this is where we finally make a difference between __str__ and the display string def __str__(self) -> str: - return self.identifier + return self.name def get_display_string(self) -> str: - return "[anonymous]" if self.is_anonymous else self.identifier + return "[anonymous]" if self.is_anonymous else self.name def describe_signature(self, signode: TextElement, mode: str, env: BuildEnvironment, prefix: str, templateArgs: str, symbol: Symbol) -> None: @@ -99,9 +101,9 @@ def describe_signature(self, signode: TextElement, mode: str, env: BuildEnvironm if self.is_anonymous: node = addnodes.desc_sig_name(text="[anonymous]") else: - node = addnodes.desc_sig_name(self.identifier, self.identifier) + node = addnodes.desc_sig_name(self.name, self.name) if mode == 'markType': - targetText = prefix + self.identifier + templateArgs + targetText = prefix + self.name + templateArgs pnode = addnodes.pending_xref('', refdomain='cpp', reftype='identifier', reftarget=targetText, modname=None, @@ -123,7 +125,7 @@ def describe_signature(self, signode: TextElement, mode: str, env: BuildEnvironm assert len(prefix) == 0 assert len(templateArgs) == 0 assert not self.is_anonymous - targetText = 'operator""' + self.identifier + targetText = 'operator""' + self.name pnode = addnodes.pending_xref('', refdomain='cpp', reftype='identifier', reftarget=targetText, modname=None, @@ -134,6 +136,14 @@ def describe_signature(self, signode: TextElement, mode: str, env: BuildEnvironm else: raise Exception('Unknown description mode: %s' % mode) + @property + def identifier(self) -> str: + warnings.warn( + '`ASTIdentifier.identifier` is deprecated, use `ASTIdentifier.name` instead', + DeprecationWarning, stacklevel=2, + ) + return self.name + class ASTNestedNameElement(ASTBase): def __init__(self, identOrOp: ASTIdentifier | ASTOperator, From 4a2f8fcc0a7b101884ef342ac0ab78411a2e1753 Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Tue, 23 Apr 2024 04:40:47 +0100 Subject: [PATCH 5/6] Add a ``repr`` for ``Symbol`` --- sphinx/domains/c/_symbol.py | 3 +++ sphinx/domains/cpp/_symbol.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/sphinx/domains/c/_symbol.py b/sphinx/domains/c/_symbol.py index 5205204c4ce..fd1c0d05d6d 100644 --- a/sphinx/domains/c/_symbol.py +++ b/sphinx/domains/c/_symbol.py @@ -114,6 +114,9 @@ def __init__( # Do symbol addition after self._children has been initialised. self._add_function_params() + def __repr__(self) -> str: + return f'' + def _fill_empty(self, declaration: ASTDeclaration, docname: str, line: int) -> None: self._assert_invariants() assert self.declaration is None diff --git a/sphinx/domains/cpp/_symbol.py b/sphinx/domains/cpp/_symbol.py index 4caa43070a1..14c8f5fe672 100644 --- a/sphinx/domains/cpp/_symbol.py +++ b/sphinx/domains/cpp/_symbol.py @@ -155,6 +155,9 @@ def __init__(self, parent: Symbol | None, # Do symbol addition after self._children has been initialised. self._add_template_and_function_params() + def __repr__(self) -> str: + return f'' + def _fill_empty(self, declaration: ASTDeclaration, docname: str, line: int) -> None: self._assert_invariants() assert self.declaration is None From 5c4db0bbb97772f7f639ecc9cf16bff0c447bc68 Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Tue, 23 Apr 2024 04:41:18 +0100 Subject: [PATCH 6/6] Serialisation improvements in ``cfamily`` --- sphinx/util/cfamily.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/sphinx/util/cfamily.py b/sphinx/util/cfamily.py index f7800405562..53f38685a26 100644 --- a/sphinx/util/cfamily.py +++ b/sphinx/util/cfamily.py @@ -109,7 +109,7 @@ def get_display_string(self) -> str: return self._stringify(lambda ast: ast.get_display_string()) def __repr__(self) -> str: - return '<%s>' % self.__class__.__name__ + return f'<{self.__class__.__name__}: {self._stringify(repr)}>' ################################################################################ @@ -134,7 +134,7 @@ def __hash__(self) -> int: return hash(self.arg) def _stringify(self, transform: StringifyTransform) -> str: - return "[[" + self.arg + "]]" + return f"[[{self.arg}]]" def describe_signature(self, signode: TextElement) -> None: signode.append(addnodes.desc_sig_punctuation('[[', '[[')) @@ -156,10 +156,9 @@ def __hash__(self) -> int: return hash((self.name, self.args)) def _stringify(self, transform: StringifyTransform) -> str: - res = [self.name] if self.args: - res.append(transform(self.args)) - return ''.join(res) + return self.name + transform(self.args) + return self.name class ASTGnuAttributeList(ASTAttribute): @@ -175,19 +174,11 @@ def __hash__(self) -> int: return hash(self.attrs) def _stringify(self, transform: StringifyTransform) -> str: - res = ['__attribute__(('] - first = True - for attr in self.attrs: - if not first: - res.append(', ') - first = False - res.append(transform(attr)) - res.append('))') - return ''.join(res) + attrs = ', '.join(map(transform, self.attrs)) + return f'__attribute__(({attrs}))' def describe_signature(self, signode: TextElement) -> None: - txt = str(self) - signode.append(nodes.Text(txt)) + signode.append(nodes.Text(str(self))) class ASTIdAttribute(ASTAttribute): @@ -227,11 +218,10 @@ def __hash__(self) -> int: return hash((self.id, self.arg)) def _stringify(self, transform: StringifyTransform) -> str: - return self.id + '(' + self.arg + ')' + return f'{self.id}({self.arg})' def describe_signature(self, signode: TextElement) -> None: - txt = str(self) - signode.append(nodes.Text(txt)) + signode.append(nodes.Text(str(self))) class ASTAttributeList(ASTBaseBase): @@ -253,7 +243,7 @@ def __add__(self, other: ASTAttributeList) -> ASTAttributeList: return ASTAttributeList(self.attrs + other.attrs) def _stringify(self, transform: StringifyTransform) -> str: - return ' '.join(transform(attr) for attr in self.attrs) + return ' '.join(map(transform, self.attrs)) def describe_signature(self, signode: TextElement) -> None: if len(self.attrs) == 0: