Skip to content

Commit f26d492

Browse files
authored
[config] protect is_serializable against circular references (#12196)
1 parent 885818b commit f26d492

File tree

2 files changed

+208
-5
lines changed

2 files changed

+208
-5
lines changed

sphinx/config.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,30 @@ class ConfigValue(NamedTuple):
5151
rebuild: _ConfigRebuild
5252

5353

54-
def is_serializable(obj: Any) -> bool:
54+
def is_serializable(obj: object, *, _recursive_guard: frozenset[int] = frozenset()) -> bool:
5555
"""Check if object is serializable or not."""
5656
if isinstance(obj, UNSERIALIZABLE_TYPES):
5757
return False
58-
elif isinstance(obj, dict):
58+
59+
# use id() to handle un-hashable objects
60+
if id(obj) in _recursive_guard:
61+
return True
62+
63+
if isinstance(obj, dict):
64+
guard = _recursive_guard | {id(obj)}
5965
for key, value in obj.items():
60-
if not is_serializable(key) or not is_serializable(value):
66+
if (
67+
not is_serializable(key, _recursive_guard=guard)
68+
or not is_serializable(value, _recursive_guard=guard)
69+
):
6170
return False
62-
elif isinstance(obj, (list, tuple, set)):
63-
return all(map(is_serializable, obj))
71+
elif isinstance(obj, (list, tuple, set, frozenset)):
72+
guard = _recursive_guard | {id(obj)}
73+
return all(is_serializable(item, _recursive_guard=guard) for item in obj)
6474

75+
# if an issue occurs for a non-serializable type, pickle will complain
76+
# since the object is likely coming from a third-party extension (we
77+
# natively expect 'simple' types and not weird ones)
6578
return True
6679

6780

tests/test_config/test_config.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
"""Test the sphinx.config.Config class."""
2+
from __future__ import annotations
3+
24
import pickle
35
import time
6+
from collections import Counter
47
from pathlib import Path
8+
from typing import TYPE_CHECKING
59
from unittest import mock
610

711
import pytest
@@ -14,10 +18,51 @@
1418
_Opt,
1519
check_confval_types,
1620
correct_copyright_year,
21+
is_serializable,
1722
)
1823
from sphinx.deprecation import RemovedInSphinx90Warning
1924
from sphinx.errors import ConfigError, ExtensionError, VersionRequirementError
2025

26+
if TYPE_CHECKING:
27+
from collections.abc import Iterable
28+
from typing import Union
29+
30+
CircularList = list[Union[int, 'CircularList']]
31+
CircularDict = dict[str, Union[int, 'CircularDict']]
32+
33+
34+
def check_is_serializable(subject: object, *, circular: bool) -> None:
35+
assert is_serializable(subject)
36+
37+
if circular:
38+
class UselessGuard(frozenset[int]):
39+
def __or__(self, other: object, /) -> UselessGuard:
40+
# do nothing
41+
return self
42+
43+
def union(self, *args: Iterable[object]) -> UselessGuard:
44+
# do nothing
45+
return self
46+
47+
# check that without recursive guards, a recursion error occurs
48+
with pytest.raises(RecursionError):
49+
assert is_serializable(subject, _recursive_guard=UselessGuard())
50+
51+
52+
def test_is_serializable() -> None:
53+
subject = [1, [2, {3, 'a'}], {'x': {'y': frozenset((4, 5))}}]
54+
check_is_serializable(subject, circular=False)
55+
56+
a, b = [1], [2] # type: (CircularList, CircularList)
57+
a.append(b)
58+
b.append(a)
59+
check_is_serializable(a, circular=True)
60+
check_is_serializable(b, circular=True)
61+
62+
x: CircularDict = {'a': 1, 'b': {'c': 1}}
63+
x['b'] = x
64+
check_is_serializable(x, circular=True)
65+
2166

2267
def test_config_opt_deprecated(recwarn):
2368
opt = _Opt('default', '', ())
@@ -102,6 +147,151 @@ def test_config_pickle_protocol(tmp_path, protocol: int):
102147
assert repr(config) == repr(pickled_config)
103148

104149

150+
def test_config_pickle_circular_reference_in_list():
151+
a, b = [1], [2] # type: (CircularList, CircularList)
152+
a.append(b)
153+
b.append(a)
154+
155+
check_is_serializable(a, circular=True)
156+
check_is_serializable(b, circular=True)
157+
158+
config = Config()
159+
config.add('a', [], '', types=list)
160+
config.add('b', [], '', types=list)
161+
config.a, config.b = a, b
162+
163+
actual = pickle.loads(pickle.dumps(config))
164+
assert isinstance(actual.a, list)
165+
check_is_serializable(actual.a, circular=True)
166+
167+
assert isinstance(actual.b, list)
168+
check_is_serializable(actual.b, circular=True)
169+
170+
assert actual.a[0] == 1
171+
assert actual.a[1][0] == 2
172+
assert actual.a[1][1][0] == 1
173+
assert actual.a[1][1][1][0] == 2
174+
175+
assert actual.b[0] == 2
176+
assert actual.b[1][0] == 1
177+
assert actual.b[1][1][0] == 2
178+
assert actual.b[1][1][1][0] == 1
179+
180+
assert len(actual.a) == 2
181+
assert len(actual.a[1]) == 2
182+
assert len(actual.a[1][1]) == 2
183+
assert len(actual.a[1][1][1]) == 2
184+
assert len(actual.a[1][1][1][1]) == 2
185+
186+
assert len(actual.b) == 2
187+
assert len(actual.b[1]) == 2
188+
assert len(actual.b[1][1]) == 2
189+
assert len(actual.b[1][1][1]) == 2
190+
assert len(actual.b[1][1][1][1]) == 2
191+
192+
def check(
193+
u: list[list[object] | int],
194+
v: list[list[object] | int],
195+
*,
196+
counter: Counter[type, int] | None = None,
197+
guard: frozenset[int] = frozenset(),
198+
) -> Counter[type, int]:
199+
counter = Counter() if counter is None else counter
200+
201+
if id(u) in guard and id(v) in guard:
202+
return counter
203+
204+
if isinstance(u, int):
205+
assert v.__class__ is u.__class__
206+
assert u == v
207+
counter[type(u)] += 1
208+
return counter
209+
210+
assert isinstance(u, list)
211+
assert v.__class__ is u.__class__
212+
assert len(u) == len(v)
213+
214+
for u_i, v_i in zip(u, v):
215+
counter[type(u)] += 1
216+
check(u_i, v_i, counter=counter, guard=guard | {id(u), id(v)})
217+
218+
return counter
219+
220+
counter = check(actual.a, a)
221+
# check(actual.a, a)
222+
# check(actual.a[0], a[0]) -> ++counter[dict]
223+
# ++counter[int] (a[0] is an int)
224+
# check(actual.a[1], a[1]) -> ++counter[dict]
225+
# check(actual.a[1][0], a[1][0]) -> ++counter[dict]
226+
# ++counter[int] (a[1][0] is an int)
227+
# check(actual.a[1][1], a[1][1]) -> ++counter[dict]
228+
# recursive guard since a[1][1] == a
229+
assert counter[type(a[0])] == 2
230+
assert counter[type(a[1])] == 4
231+
232+
# same logic as above
233+
counter = check(actual.b, b)
234+
assert counter[type(b[0])] == 2
235+
assert counter[type(b[1])] == 4
236+
237+
238+
def test_config_pickle_circular_reference_in_dict():
239+
x: CircularDict = {'a': 1, 'b': {'c': 1}}
240+
x['b'] = x
241+
check_is_serializable(x, circular=True)
242+
243+
config = Config()
244+
config.add('x', [], '', types=dict)
245+
config.x = x
246+
247+
actual = pickle.loads(pickle.dumps(config))
248+
check_is_serializable(actual.x, circular=True)
249+
assert isinstance(actual.x, dict)
250+
251+
assert actual.x['a'] == 1
252+
assert actual.x['b']['a'] == 1
253+
254+
assert len(actual.x) == 2
255+
assert len(actual.x['b']) == 2
256+
assert len(actual.x['b']['b']) == 2
257+
258+
def check(
259+
u: dict[str, dict[str, object] | int],
260+
v: dict[str, dict[str, object] | int],
261+
*,
262+
counter: Counter[type, int] | None = None,
263+
guard: frozenset[int] = frozenset(),
264+
) -> Counter:
265+
counter = Counter() if counter is None else counter
266+
267+
if id(u) in guard and id(v) in guard:
268+
return counter
269+
270+
if isinstance(u, int):
271+
assert v.__class__ is u.__class__
272+
assert u == v
273+
counter[type(u)] += 1
274+
return counter
275+
276+
assert isinstance(u, dict)
277+
assert v.__class__ is u.__class__
278+
assert len(u) == len(v)
279+
280+
for u_i, v_i in zip(u, v):
281+
counter[type(u)] += 1
282+
check(u[u_i], v[v_i], counter=counter, guard=guard | {id(u), id(v)})
283+
return counter
284+
285+
counters = check(actual.x, x, counter=Counter())
286+
# check(actual.x, x)
287+
# check(actual.x['a'], x['a']) -> ++counter[dict]
288+
# ++counter[int] (x['a'] is an int)
289+
# check(actual.x['b'], x['b']) -> ++counter[dict]
290+
# recursive guard since x['b'] == x
291+
assert counters[type(x['a'])] == 1
292+
assert counters[type(x['b'])] == 2
293+
294+
105295
def test_extension_values():
106296
config = Config()
107297

0 commit comments

Comments
 (0)