Skip to content

Commit 450ad63

Browse files
Begin using session-based HTTP requests in the linkcheck builder (#11503)
Co-authored-by: Adam Turner <[email protected]>
1 parent 1cb52d5 commit 450ad63

File tree

3 files changed

+48
-49
lines changed

3 files changed

+48
-49
lines changed

sphinx/builders/linkcheck.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -279,12 +279,16 @@ def __init__(self, config: Config,
279279
self.tls_verify = config.tls_verify
280280
self.tls_cacerts = config.tls_cacerts
281281

282+
self._session = requests._Session()
283+
282284
super().__init__(daemon=True)
283285

284286
def run(self) -> None:
285287
while True:
286288
next_check, hyperlink = self.wqueue.get()
287289
if hyperlink is None:
290+
# An empty hyperlink is a signal to shut down the worker; clean up resources here
291+
self._session.close()
288292
break
289293

290294
uri, docname, _docpath, lineno = hyperlink
@@ -346,6 +350,13 @@ def _check(self, docname: str, uri: str, hyperlink: Hyperlink) -> tuple[str, str
346350

347351
return status, info, code
348352

353+
def _retrieval_methods(self,
354+
check_anchors: bool,
355+
anchor: str) -> Iterator[tuple[Callable, dict]]:
356+
if not check_anchors or not anchor:
357+
yield self._session.head, {'allow_redirects': True}
358+
yield self._session.get, {'stream': True}
359+
349360
def _check_uri(self, uri: str, hyperlink: Hyperlink) -> tuple[str, str, int]:
350361
req_url, delimiter, anchor = uri.partition('#')
351362
for rex in self.anchors_ignore if delimiter and anchor else []:
@@ -377,7 +388,7 @@ def _check_uri(self, uri: str, hyperlink: Hyperlink) -> tuple[str, str, int]:
377388
error_message = ''
378389
status_code = -1
379390
response_url = retry_after = ''
380-
for retrieval_method, kwargs in _retrieval_methods(self.check_anchors, anchor):
391+
for retrieval_method, kwargs in self._retrieval_methods(self.check_anchors, anchor):
381392
try:
382393
with retrieval_method(
383394
url=req_url, auth=auth_info,
@@ -508,12 +519,6 @@ def _get_request_headers(
508519
return {}
509520

510521

511-
def _retrieval_methods(check_anchors: bool, anchor: str) -> Iterator[tuple[Callable, dict]]:
512-
if not check_anchors or not anchor:
513-
yield requests.head, {'allow_redirects': True}
514-
yield requests.get, {'stream': True}
515-
516-
517522
def contains_anchor(response: Response, anchor: str) -> bool:
518523
"""Determine if an anchor is contained within an HTTP response."""
519524

sphinx/util/requests.py

Lines changed: 35 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
from __future__ import annotations
44

55
import warnings
6-
from contextlib import contextmanager
7-
from typing import Any, Iterator
6+
from typing import Any
87
from urllib.parse import urlsplit
98

109
import requests
@@ -16,15 +15,6 @@
1615
f'Sphinx/{sphinx.__version__}')
1716

1817

19-
@contextmanager
20-
def ignore_insecure_warning(verify: bool) -> Iterator[None]:
21-
with warnings.catch_warnings():
22-
if not verify:
23-
# ignore InsecureRequestWarning if verify=False
24-
warnings.filterwarnings("ignore", category=InsecureRequestWarning)
25-
yield
26-
27-
2818
def _get_tls_cacert(url: str, certs: str | dict[str, str] | None) -> str | bool:
2919
"""Get additional CA cert for a specific URL."""
3020
if not certs:
@@ -39,41 +29,45 @@ def _get_tls_cacert(url: str, certs: str | dict[str, str] | None) -> str | bool:
3929
return certs.get(hostname, True)
4030

4131

42-
def get(url: str,
43-
_user_agent: str = '',
44-
_tls_info: tuple[bool, str | dict[str, str] | None] = (), # type: ignore[assignment]
45-
**kwargs: Any) -> requests.Response:
46-
"""Sends a HEAD request like requests.head().
32+
def get(url: str, **kwargs: Any) -> requests.Response:
33+
"""Sends a GET request like requests.get().
4734
4835
This sets up User-Agent header and TLS verification automatically."""
49-
headers = kwargs.setdefault('headers', {})
50-
headers.setdefault('User-Agent', _user_agent or _USER_AGENT)
51-
if _tls_info:
52-
tls_verify, tls_cacerts = _tls_info
53-
verify = bool(kwargs.get('verify', tls_verify))
54-
kwargs.setdefault('verify', verify and _get_tls_cacert(url, tls_cacerts))
55-
else:
56-
verify = kwargs.get('verify', True)
36+
with _Session() as session:
37+
return session.get(url, **kwargs)
5738

58-
with ignore_insecure_warning(verify):
59-
return requests.get(url, **kwargs)
6039

61-
62-
def head(url: str,
63-
_user_agent: str = '',
64-
_tls_info: tuple[bool, str | dict[str, str] | None] = (), # type: ignore[assignment]
65-
**kwargs: Any) -> requests.Response:
40+
def head(url: str, **kwargs: Any) -> requests.Response:
6641
"""Sends a HEAD request like requests.head().
6742
6843
This sets up User-Agent header and TLS verification automatically."""
69-
headers = kwargs.setdefault('headers', {})
70-
headers.setdefault('User-Agent', _user_agent or _USER_AGENT)
71-
if _tls_info:
72-
tls_verify, tls_cacerts = _tls_info
73-
verify = bool(kwargs.get('verify', tls_verify))
74-
kwargs.setdefault('verify', verify and _get_tls_cacert(url, tls_cacerts))
75-
else:
76-
verify = kwargs.get('verify', True)
44+
with _Session() as session:
45+
return session.head(url, **kwargs)
7746

78-
with ignore_insecure_warning(verify):
79-
return requests.head(url, **kwargs)
47+
48+
class _Session(requests.Session):
49+
def request( # type: ignore[override]
50+
self, method: str, url: str,
51+
_user_agent: str = '',
52+
_tls_info: tuple[bool, str | dict[str, str] | None] = (), # type: ignore[assignment]
53+
**kwargs: Any,
54+
) -> requests.Response:
55+
"""Sends a request with an HTTP verb and url.
56+
57+
This sets up User-Agent header and TLS verification automatically."""
58+
headers = kwargs.setdefault('headers', {})
59+
headers.setdefault('User-Agent', _user_agent or _USER_AGENT)
60+
if _tls_info:
61+
tls_verify, tls_cacerts = _tls_info
62+
verify = bool(kwargs.get('verify', tls_verify))
63+
kwargs.setdefault('verify', verify and _get_tls_cacert(url, tls_cacerts))
64+
else:
65+
verify = kwargs.get('verify', True)
66+
67+
if verify:
68+
return super().request(method, url, **kwargs)
69+
70+
with warnings.catch_warnings():
71+
# ignore InsecureRequestWarning if verify=False
72+
warnings.filterwarnings("ignore", category=InsecureRequestWarning)
73+
return super().request(method, url, **kwargs)

tests/test_build_linkcheck.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_defaults(app):
104104
with http_server(DefaultsHandler):
105105
with ConnectionMeasurement() as m:
106106
app.build()
107-
assert m.connection_count <= 10
107+
assert m.connection_count <= 5
108108

109109
# Text output
110110
assert (app.outdir / 'output.txt').exists()

0 commit comments

Comments
 (0)