Skip to content

Commit bea8b6b

Browse files
authored
[internals] improve type safety when using NodeMatcher (#12034)
1 parent 265ffee commit bea8b6b

File tree

5 files changed

+37
-31
lines changed

5 files changed

+37
-31
lines changed

sphinx/builders/html/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def run(self, **kwargs: Any) -> None:
4747
matcher = NodeMatcher(nodes.literal, classes=["kbd"])
4848
# this list must be pre-created as during iteration new nodes
4949
# are added which match the condition in the NodeMatcher.
50-
for node in list(self.document.findall(matcher)): # type: nodes.literal
50+
for node in list(matcher.findall(self.document)):
5151
parts = self.pattern.split(node[-1].astext())
5252
if len(parts) == 1 or self.is_multiwords_key(parts):
5353
continue

sphinx/builders/latex/transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class FootnoteDocnameUpdater(SphinxTransform):
3737

3838
def apply(self, **kwargs: Any) -> None:
3939
matcher = NodeMatcher(*self.TARGET_NODES)
40-
for node in self.document.findall(matcher): # type: Element
40+
for node in matcher.findall(self.document):
4141
node['docname'] = self.env.docname
4242

4343

@@ -538,7 +538,7 @@ class CitationReferenceTransform(SphinxPostTransform):
538538
def run(self, **kwargs: Any) -> None:
539539
domain = cast(CitationDomain, self.env.get_domain('citation'))
540540
matcher = NodeMatcher(addnodes.pending_xref, refdomain='citation', reftype='ref')
541-
for node in self.document.findall(matcher): # type: addnodes.pending_xref
541+
for node in matcher.findall(self.document):
542542
docname, labelid, _ = domain.citations.get(node['reftarget'], ('', '', 0))
543543
if docname:
544544
citation_ref = nodes.citation_reference('', '', *node.children,
@@ -574,7 +574,7 @@ class LiteralBlockTransform(SphinxPostTransform):
574574

575575
def run(self, **kwargs: Any) -> None:
576576
matcher = NodeMatcher(nodes.container, literal_block=True)
577-
for node in self.document.findall(matcher): # type: nodes.container
577+
for node in matcher.findall(self.document):
578578
newnode = captioned_literal_block('', *node.children, **node.attributes)
579579
node.replace_self(newnode)
580580

sphinx/transforms/i18n.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def update_title_mapping(self) -> bool:
182182

183183
# replace target's refname to new target name
184184
matcher = NodeMatcher(nodes.target, refname=old_name)
185-
for old_target in self.document.findall(matcher): # type: nodes.target
185+
for old_target in matcher.findall(self.document):
186186
old_target['refname'] = new_name
187187

188188
processed = True
@@ -198,10 +198,8 @@ def list_replace_or_append(lst: list[N], old: N, new: N) -> None:
198198
lst.append(new)
199199

200200
is_autofootnote_ref = NodeMatcher(nodes.footnote_reference, auto=Any)
201-
old_foot_refs: list[nodes.footnote_reference] = [
202-
*self.node.findall(is_autofootnote_ref)]
203-
new_foot_refs: list[nodes.footnote_reference] = [
204-
*self.patch.findall(is_autofootnote_ref)]
201+
old_foot_refs = list(is_autofootnote_ref.findall(self.node))
202+
new_foot_refs = list(is_autofootnote_ref.findall(self.patch))
205203
self.compare_references(old_foot_refs, new_foot_refs,
206204
__('inconsistent footnote references in translated message.' +
207205
' original: {0}, translated: {1}'))
@@ -240,8 +238,8 @@ def update_refnamed_references(self) -> None:
240238
# * use translated refname for section refname.
241239
# * inline reference "`Python <...>`_" has no 'refname'.
242240
is_refnamed_ref = NodeMatcher(nodes.reference, refname=Any)
243-
old_refs: list[nodes.reference] = [*self.node.findall(is_refnamed_ref)]
244-
new_refs: list[nodes.reference] = [*self.patch.findall(is_refnamed_ref)]
241+
old_refs = list(is_refnamed_ref.findall(self.node))
242+
new_refs = list(is_refnamed_ref.findall(self.patch))
245243
self.compare_references(old_refs, new_refs,
246244
__('inconsistent references in translated message.' +
247245
' original: {0}, translated: {1}'))
@@ -264,10 +262,8 @@ def update_refnamed_references(self) -> None:
264262
def update_refnamed_footnote_references(self) -> None:
265263
# refnamed footnote should use original 'ids'.
266264
is_refnamed_footnote_ref = NodeMatcher(nodes.footnote_reference, refname=Any)
267-
old_foot_refs: list[nodes.footnote_reference] = [*self.node.findall(
268-
is_refnamed_footnote_ref)]
269-
new_foot_refs: list[nodes.footnote_reference] = [*self.patch.findall(
270-
is_refnamed_footnote_ref)]
265+
old_foot_refs = list(is_refnamed_footnote_ref.findall(self.node))
266+
new_foot_refs = list(is_refnamed_footnote_ref.findall(self.patch))
271267
refname_ids_map: dict[str, list[str]] = {}
272268
self.compare_references(old_foot_refs, new_foot_refs,
273269
__('inconsistent footnote references in translated message.' +
@@ -282,8 +278,8 @@ def update_refnamed_footnote_references(self) -> None:
282278
def update_citation_references(self) -> None:
283279
# citation should use original 'ids'.
284280
is_citation_ref = NodeMatcher(nodes.citation_reference, refname=Any)
285-
old_cite_refs: list[nodes.citation_reference] = [*self.node.findall(is_citation_ref)]
286-
new_cite_refs: list[nodes.citation_reference] = [*self.patch.findall(is_citation_ref)]
281+
old_cite_refs = list(is_citation_ref.findall(self.node))
282+
new_cite_refs = list(is_citation_ref.findall(self.patch))
287283
self.compare_references(old_cite_refs, new_cite_refs,
288284
__('inconsistent citation references in translated message.' +
289285
' original: {0}, translated: {1}'))
@@ -549,7 +545,7 @@ def apply(self, **kwargs: Any) -> None:
549545
return
550546

551547
total = translated = 0
552-
for node in self.document.findall(NodeMatcher(translated=Any)): # type: nodes.Element
548+
for node in NodeMatcher(nodes.Element, translated=Any).findall(self.document):
553549
total += 1
554550
if node['translated']:
555551
translated += 1
@@ -588,7 +584,7 @@ def apply(self, **kwargs: Any) -> None:
588584
'True, False, "translated" or "untranslated"')
589585
raise ConfigError(msg)
590586

591-
for node in self.document.findall(NodeMatcher(translated=Any)): # type: nodes.Element
587+
for node in NodeMatcher(nodes.Element, translated=Any).findall(self.document):
592588
if node['translated']:
593589
if add_translated:
594590
node.setdefault('classes', []).append('translated')
@@ -610,7 +606,7 @@ def apply(self, **kwargs: Any) -> None:
610606
return
611607

612608
matcher = NodeMatcher(nodes.inline, translatable=Any)
613-
for inline in list(self.document.findall(matcher)): # type: nodes.inline
609+
for inline in matcher.findall(self.document):
614610
inline.parent.remove(inline)
615611
inline.parent += inline.children
616612

sphinx/util/nodes.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,19 @@
55
import contextlib
66
import re
77
import unicodedata
8-
from typing import TYPE_CHECKING, Any, Callable
8+
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
99

1010
from docutils import nodes
11+
from docutils.nodes import Node
1112

1213
from sphinx import addnodes
1314
from sphinx.locale import __
1415
from sphinx.util import logging
1516

1617
if TYPE_CHECKING:
17-
from collections.abc import Iterable
18+
from collections.abc import Iterable, Iterator
1819

19-
from docutils.nodes import Element, Node
20+
from docutils.nodes import Element
2021
from docutils.parsers.rst import Directive
2122
from docutils.parsers.rst.states import Inliner
2223
from docutils.statemachine import StringList
@@ -33,7 +34,10 @@
3334
caption_ref_re = explicit_title_re # b/w compat alias
3435

3536

36-
class NodeMatcher:
37+
N = TypeVar("N", bound=Node)
38+
39+
40+
class NodeMatcher(Generic[N]):
3741
"""A helper class for Node.findall().
3842
3943
It checks that the given node is an instance of the specified node-classes and
@@ -43,20 +47,18 @@ class NodeMatcher:
4347
and ``reftype`` attributes::
4448
4549
matcher = NodeMatcher(nodes.reference, refdomain='std', reftype='citation')
46-
doctree.findall(matcher)
50+
matcher.findall(doctree)
4751
# => [<reference ...>, <reference ...>, ...]
4852
4953
A special value ``typing.Any`` matches any kind of node-attributes. For example,
5054
following example searches ``reference`` node having ``refdomain`` attributes::
5155
52-
from __future__ import annotations
53-
from typing import TYPE_CHECKING, Any
5456
matcher = NodeMatcher(nodes.reference, refdomain=Any)
55-
doctree.findall(matcher)
57+
matcher.findall(doctree)
5658
# => [<reference ...>, <reference ...>, ...]
5759
"""
5860

59-
def __init__(self, *node_classes: type[Node], **attrs: Any) -> None:
61+
def __init__(self, *node_classes: type[N], **attrs: Any) -> None:
6062
self.classes = node_classes
6163
self.attrs = attrs
6264

@@ -85,6 +87,14 @@ def match(self, node: Node) -> bool:
8587
def __call__(self, node: Node) -> bool:
8688
return self.match(node)
8789

90+
def findall(self, node: Node) -> Iterator[N]:
91+
"""An alternative to `Node.findall` with improved type safety.
92+
93+
While the `NodeMatcher` object can be used as an argument to `Node.findall`, doing so
94+
confounds type checkers' ability to determine the return type of the iterator.
95+
"""
96+
return node.findall(self)
97+
8898

8999
def get_full_module_name(node: Node) -> str:
90100
"""
@@ -308,7 +318,7 @@ def traverse_translatable_index(
308318
) -> Iterable[tuple[Element, list[tuple[str, str, str, str, str | None]]]]:
309319
"""Traverse translatable index node from a document tree."""
310320
matcher = NodeMatcher(addnodes.index, inline=False)
311-
for node in doctree.findall(matcher): # type: addnodes.index
321+
for node in matcher.findall(doctree):
312322
if 'raw_entries' in node:
313323
entries = node['raw_entries']
314324
else:

sphinx/writers/manpage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(self, document: nodes.document) -> None:
5555

5656
def apply(self, **kwargs: Any) -> None:
5757
matcher = NodeMatcher(nodes.literal, nodes.emphasis, nodes.strong)
58-
for node in list(self.document.findall(matcher)): # type: nodes.TextElement
58+
for node in list(matcher.findall(self.document)):
5959
if any(matcher(subnode) for subnode in node):
6060
pos = node.parent.index(node)
6161
for subnode in reversed(list(node)):

0 commit comments

Comments
 (0)