Skip to content

Commit c09480e

Browse files
authored
🐛 fix(parser): prevent directive side-effects in snippet parsing (#624)
Extensions like sphinx-needs register unique identifiers when their directives execute. Since `get_insert_index` parses docstrings a second time to locate field lists for `:rtype:` insertion, those directive handlers run twice — causing duplicate ID errors and other problems for any extension with non-idempotent side-effects. Rather than replacing the RST parser with regex-based field detection (which can't handle definition lists and other complex RST structures), we intercept the directive lookup during snippet parsing and replace any non-builtin directive with a no-op handler. The builtin docutils set is captured at import time before extensions register theirs. The document tree structure is fully preserved for the existing insertion logic. Fixes #510
1 parent 5ea2ec7 commit c09480e

File tree

2 files changed

+137
-35
lines changed

2 files changed

+137
-35
lines changed

src/sphinx_autodoc_typehints/__init__.py

Lines changed: 67 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,25 @@
1616

1717
from docutils import nodes
1818
from docutils.frontend import get_default_settings
19+
from docutils.parsers.rst import Directive, directives
20+
from docutils.utils import new_document
1921
from sphinx.ext.autodoc.mock import mock
2022
from sphinx.parsers import RSTParser
2123
from sphinx.util import logging, rst
24+
from sphinx.util.docutils import sphinx_domains
2225
from sphinx.util.inspect import TypeAliasForwardRef, stringify_signature
2326
from sphinx.util.inspect import signature as sphinx_signature
2427

25-
from ._parser import parse
28+
from ._parser import _RstSnippetParser, parse
2629
from .patches import install_patches
2730
from .version import __version__
2831

2932
if TYPE_CHECKING:
33+
import optparse
3034
from ast import FunctionDef, Module, stmt
3135
from collections.abc import Callable
3236

37+
from docutils.frontend import Values
3338
from docutils.nodes import Node
3439
from docutils.parsers.rst import states
3540
from sphinx.application import Sphinx
@@ -38,6 +43,8 @@
3843
from sphinx.ext.autodoc import Options
3944

4045
_LOGGER = logging.getLogger(__name__)
46+
47+
_BUILTIN_DIRECTIVES = frozenset(directives._directive_registry) # noqa: SLF001
4148
_PYDATA_ANNOTS_TYPING = {
4249
"Any",
4350
"AnyStr",
@@ -956,50 +963,23 @@ class InsertIndexInfo:
956963
PARAM_SYNONYMS = ("param ", "parameter ", "arg ", "argument ", "keyword ", "kwarg ", "kwparam ")
957964

958965

959-
def node_line_no(node: Node) -> int | None:
960-
"""
961-
Get the 1-indexed line on which the node starts if possible. If not, return None.
962-
963-
Descend through the first children until we locate one with a line number or return None if None of them have one.
964-
965-
I'm not aware of any rst on which this returns None, to find out would require a more detailed analysis of the
966-
docutils rst parser source code. An example where the node doesn't have a line number but the first child does is
967-
all `definition_list` nodes. It seems like bullet_list and option_list get line numbers, but enum_list also doesn't.
968-
"""
969-
if node is None:
970-
return None
971-
972-
while node.line is None and node.children:
973-
node = node.children[0]
974-
return node.line
975-
976-
977-
def tag_name(node: Node) -> str:
978-
return node.tagname
979-
980-
981966
def get_insert_index(app: Sphinx, lines: list[str]) -> InsertIndexInfo | None:
982967
# 1. If there is an existing :rtype: anywhere, don't insert anything.
983968
if any(line.startswith(":rtype:") for line in lines):
984969
return None
985970

986-
# 2. If there is a :returns: anywhere, either modify that line or insert
987-
# just before it.
971+
# 2. If there is a :returns: anywhere, either modify that line or insert just before it.
988972
for at, line in enumerate(lines):
989973
if line.startswith((":return:", ":returns:")):
990974
return InsertIndexInfo(insert_index=at, found_return=True)
991975

992976
# 3. Insert after the parameters.
993-
# To find the parameters, parse as a docutils tree.
994977
settings = get_default_settings(RSTParser) # type: ignore[arg-type]
995978
settings.env = app.env
996-
doc = parse("\n".join(lines), settings)
979+
doc = _safe_parse("\n".join(lines), settings)
997980

998-
# Find a top level child which is a field_list that contains a field whose
999-
# name starts with one of the PARAM_SYNONYMS. This is the parameter list. We
1000-
# hope there is at most of these.
1001981
for child in doc.children:
1002-
if tag_name(child) != "field_list":
982+
if _tag_name(child) != "field_list":
1003983
continue
1004984

1005985
if not any(c.children[0].astext().startswith(PARAM_SYNONYMS) for c in child.children):
@@ -1010,24 +990,76 @@ def get_insert_index(app: Sphinx, lines: list[str]) -> InsertIndexInfo | None:
1010990
# If there is a next sibling but we can't locate a line number, insert
1011991
# at end. (I don't know of any input where this happens.)
1012992
next_sibling = child.next_node(descend=False, siblings=True)
1013-
line_no = node_line_no(next_sibling) if next_sibling else None
993+
line_no = _node_line_no(next_sibling) if next_sibling else None
1014994
at = max(line_no - 2, 0) if line_no else len(lines)
1015995
return InsertIndexInfo(insert_index=at, found_param=True)
1016996

1017997
# 4. Insert before examples
1018998
for child in doc.children:
1019-
if tag_name(child) in {"literal_block", "paragraph", "field_list"}:
999+
if _tag_name(child) in {"literal_block", "paragraph", "field_list"}:
10201000
continue
1021-
line_no = node_line_no(child)
1001+
line_no = _node_line_no(child)
10221002
at = max(line_no - 2, 0) if line_no else len(lines)
1023-
if lines[at - 1]: # skip if something on this line
1003+
if lines[at - 1]:
10241004
break
10251005
return InsertIndexInfo(insert_index=at, found_directive=True)
10261006

10271007
# 5. Otherwise, insert at end
10281008
return InsertIndexInfo(insert_index=len(lines))
10291009

10301010

1011+
def _safe_parse(inputstr: str, settings: Values | optparse.Values) -> nodes.document:
1012+
"""
1013+
Parse RST without triggering extension directive side-effects.
1014+
1015+
Replaces non-builtin directive lookups with a no-op handler during parsing
1016+
to prevent duplicate ID registration and other side-effects from third-party
1017+
extensions like sphinx-needs.
1018+
"""
1019+
original_lookup = directives.directive
1020+
1021+
def _safe_directive_lookup(
1022+
directive_name: str,
1023+
language_module: Any,
1024+
document: Any,
1025+
) -> tuple[type[Directive] | None, list[Any]]:
1026+
cls, messages = original_lookup(directive_name, language_module, document)
1027+
if cls is not None and directive_name not in _BUILTIN_DIRECTIVES:
1028+
return _NoOpDirective, messages
1029+
return cls, messages
1030+
1031+
doc = new_document("", settings=settings) # ty: ignore[invalid-argument-type]
1032+
with sphinx_domains(settings.env):
1033+
directives.directive = _safe_directive_lookup # type: ignore[assignment]
1034+
try:
1035+
parser = _RstSnippetParser()
1036+
parser.parse(inputstr, doc)
1037+
finally:
1038+
directives.directive = original_lookup
1039+
return doc
1040+
1041+
1042+
class _NoOpDirective(Directive):
1043+
has_content = True
1044+
optional_arguments = 99
1045+
final_argument_whitespace = True
1046+
1047+
def run(self) -> list[nodes.Node]: # noqa: PLR6301
1048+
return []
1049+
1050+
1051+
def _node_line_no(node: Node) -> int | None:
1052+
if node is None:
1053+
return None
1054+
while node.line is None and node.children:
1055+
node = node.children[0]
1056+
return node.line
1057+
1058+
1059+
def _tag_name(node: Node) -> str:
1060+
return node.tagname
1061+
1062+
10311063
def _inject_rtype( # noqa: C901, PLR0911, PLR0913, PLR0917
10321064
type_hints: dict[str, Any],
10331065
original_obj: Any,

tests/test_safe_parse.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""Tests that snippet parsing doesn't trigger extension directive side-effects."""
2+
3+
from __future__ import annotations
4+
5+
import sys
6+
from pathlib import Path
7+
from textwrap import dedent
8+
from typing import TYPE_CHECKING, ClassVar
9+
10+
import pytest
11+
from docutils.parsers.rst import Directive, directives
12+
13+
if TYPE_CHECKING:
14+
from io import StringIO
15+
16+
from sphinx.testing.util import SphinxTestApp
17+
18+
19+
@pytest.mark.sphinx("text", testroot="integration")
20+
def test_extension_directive_not_executed_during_snippet_parse(
21+
app: SphinxTestApp,
22+
status: StringIO,
23+
warning: StringIO, # noqa: ARG001
24+
monkeypatch: pytest.MonkeyPatch,
25+
) -> None:
26+
"""A non-builtin directive in a docstring should only execute once (during the real build)."""
27+
directives.register_directive("tracking-directive", _TrackingDirective)
28+
_TrackingDirective.executions.clear()
29+
30+
(Path(app.srcdir) / "index.rst").write_text(
31+
dedent("""\
32+
Test
33+
====
34+
35+
.. autofunction:: mod.func_with_tracking_directive
36+
""")
37+
)
38+
39+
src = dedent("""\
40+
def func_with_tracking_directive(x: int) -> int:
41+
\"\"\"Do something.
42+
43+
:param x: A number.
44+
45+
.. tracking-directive::
46+
47+
unique-id-123
48+
49+
\"\"\"
50+
return x
51+
""")
52+
exec(compile(src, "<test>", "exec"), (mod := {})) # noqa: S102
53+
fake_module = type(sys)("mod")
54+
fake_module.__dict__.update(mod)
55+
monkeypatch.setitem(sys.modules, "mod", fake_module)
56+
57+
app.build()
58+
assert "build succeeded" in status.getvalue()
59+
assert _TrackingDirective.executions.count("unique-id-123") == 1
60+
61+
62+
class _TrackingDirective(Directive):
63+
"""Directive that records each execution to detect double-processing."""
64+
65+
has_content = True
66+
executions: ClassVar[list[str]] = []
67+
68+
def run(self) -> list:
69+
_TrackingDirective.executions.append(self.content[0] if self.content else "")
70+
return []

0 commit comments

Comments
 (0)