diff --git a/src/docstub-stubs/_stubs.pyi b/src/docstub-stubs/_stubs.pyi index 87c2554..00d8104 100644 --- a/src/docstub-stubs/_stubs.pyi +++ b/src/docstub-stubs/_stubs.pyi @@ -2,6 +2,7 @@ import enum import logging +from collections.abc import Sequence from dataclasses import dataclass from functools import wraps from pathlib import Path @@ -34,6 +35,9 @@ class ScopeType(enum.StrEnum): CLASSMETHOD = "classmethod" STATICMETHOD = "staticmethod" +_dataclass_name: cstm.Name +_dataclass_matcher: cstm.ClassDef + @dataclass(slots=True, frozen=True) class _Scope: @@ -51,7 +55,7 @@ class _Scope: def _get_docstring_node( node: cst.FunctionDef | cst.ClassDef | cst.Module, -) -> cst.SimpleString | cst.ConcatenatedString | None: ... +) -> tuple[cst.SimpleString | cst.ConcatenatedString | None, str | None]: ... def _log_error_with_line_context(cls: Py2StubTransformer) -> Py2StubTransformer: ... def _docstub_comment_directives(cls: Py2StubTransformer) -> Py2StubTransformer: ... def _inline_node_as_code(node: cst.CSTNode) -> str: ... diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index 862e40e..94ce476 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -7,6 +7,7 @@ import enum import logging +from collections.abc import Sequence from dataclasses import dataclass from functools import wraps from typing import ClassVar @@ -54,6 +55,22 @@ class ScopeType(enum.StrEnum): # docstub: on +# To be used with `libcst.matchers.matches()` to guess if a node is a dataclass +# See `test_dataclass_matcher` for supported cases +_dataclass_name: cstm.Name = cstm.Name("dataclass") +_dataclass_matcher: cstm.ClassDef = cstm.ClassDef( + decorators=[ + cstm.Decorator( + decorator=( + _dataclass_name + | cstm.Call(func=_dataclass_name | cstm.Attribute(attr=_dataclass_name)) + | cstm.Attribute(attr=_dataclass_name) + ) + ), + ] +) + + # TODO use `libcst.metadata.ScopeProvider` instead @dataclass(slots=True, frozen=True) class _Scope: @@ -81,14 +98,8 @@ def is_class_init(self) -> bool: @property def is_dataclass(self) -> bool: - if cstm.matches(self.node, cstm.ClassDef()): - # Determine if dataclass - decorators = cstm.findall(self.node, cstm.Decorator()) - is_dataclass = any( - cstm.findall(d, cstm.Name("dataclass")) for d in decorators - ) - return is_dataclass - return False + is_dataclass = cstm.matches(self.node, _dataclass_matcher) + return is_dataclass def _get_docstring_node(node): @@ -106,23 +117,35 @@ def _get_docstring_node(node): ------- docstring_node : cst.SimpleString | cst.ConcatenatedString | None The node of the docstring if found. + docstring_value : str | None + The value of the docstring if found. """ - docstring_node = None - - docstring = node.get_docstring(clean=False) - if docstring: - # Workaround to find the exact postion of a docstring - # by using its node - string_nodes = cstm.findall( - node, cstm.SimpleString() | cstm.ConcatenatedString() - ) - matching_nodes = [ - node for node in string_nodes if node.evaluated_value == docstring - ] - assert len(matching_nodes) == 1 - docstring_node = matching_nodes[0] - return docstring_node + # Copied from https://github.com/Instagram/LibCST/blob/9275a8bf7875d08659ce7b266860138bba633410/libcst/_nodes/statement.py#L1669 + body = node.body + if isinstance(body, Sequence): + if body: + expr = body[0] + else: + return (None, None) + else: + expr = body + while isinstance(expr, (cst.BaseSuite, cst.SimpleStatementLine)): + if len(expr.body) == 0: + return (None, None) + expr = expr.body[0] + if not isinstance(expr, cst.Expr): + return (None, None) + + docstring_node = expr.value + if isinstance(docstring_node, (cst.SimpleString, cst.ConcatenatedString)): + docstring_value = docstring_node.evaluated_value + else: + return (None, None) + if isinstance(docstring_value, bytes): + return (None, None) + + return docstring_node, docstring_value def _log_error_with_line_context(cls): @@ -897,7 +920,7 @@ def _annotations_from_node(self, node): """ annotations = None - docstring_node = _get_docstring_node(node) + docstring_node, docstring_value = _get_docstring_node(node) if docstring_node: position = self.get_metadata( cst.metadata.PositionProvider, docstring_node @@ -907,7 +930,7 @@ def _annotations_from_node(self, node): ) try: annotations = DocstringAnnotations( - docstring_node.evaluated_value, + docstring_value, transformer=self.transformer, reporter=reporter, ) diff --git a/tests/test_stubs.py b/tests/test_stubs.py index 10dc36c..8d49dcf 100644 --- a/tests/test_stubs.py +++ b/tests/test_stubs.py @@ -6,7 +6,7 @@ import libcst.matchers as cstm import pytest -from docstub._stubs import Py2StubTransformer, _get_docstring_node +from docstub._stubs import Py2StubTransformer, _dataclass_matcher, _get_docstring_node class Test_get_docstring_node: @@ -761,3 +761,27 @@ def foo(*args: str, **kwargs: int) -> None: ... transformer = Py2StubTransformer() result = transformer.python_to_stub(source) assert expected == result + + +@pytest.mark.parametrize( + ("decorators", "expected"), + [ + ("@dataclass", True), + ("@dataclass(frozen=True)", True), + ("@dataclasses.dataclass(frozen=True)", True), + ("@dc.dataclass", True), + ("", False), + ("@other", False), + ("@other(dataclass=True)", False), + ], +) +def test_dataclass_matcher(decorators, expected): + source = dedent( + """ + {decorators} + class Foo: + pass + """ + ).format(decorators=decorators) + class_def = cst.parse_statement(source) + assert cstm.matches(class_def, _dataclass_matcher) is expected