Skip to content

Commit 27ba245

Browse files
authored
Fix crash from concurrent nb::make_iterator<> under free-threading. (#832)
The `State` class used by make_iterator is constructed lazily, but without locking it is possible for the caller to crash when the class type is only partially constructed. This PR adds an ft_mutex around the binding of the State type. My initial instinct was to use an nb_object_guard on `scope`. Unfortunately this doesn't work; I suspect PyEval_SaveThread() is called during class binding and that releases the outer critical section.
1 parent 3239284 commit 27ba245

File tree

4 files changed

+52
-45
lines changed

4 files changed

+52
-45
lines changed

include/nanobind/make_iterator.h

Lines changed: 23 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -71,27 +71,30 @@ typed<iterator, ValueType> make_iterator_impl(handle scope, const char *name,
7171
"make_iterator_impl(): the generated __next__ would copy elements, so the "
7272
"element type must be copy-constructible");
7373

74-
if (!type<State>().is_valid()) {
75-
class_<State>(scope, name)
76-
.def("__iter__", [](handle h) { return h; })
77-
.def("__next__",
78-
[](State &s) -> ValueType {
79-
if (!s.first_or_done)
80-
++s.it;
81-
else
82-
s.first_or_done = false;
83-
84-
if (s.it == s.end) {
85-
s.first_or_done = true;
86-
throw stop_iteration();
87-
}
88-
89-
return Access()(s.it);
90-
},
91-
std::forward<Extra>(extra)...,
92-
Policy);
74+
{
75+
static ft_mutex mu;
76+
ft_lock_guard lock(mu);
77+
if (!type<State>().is_valid()) {
78+
class_<State>(scope, name)
79+
.def("__iter__", [](handle h) { return h; })
80+
.def("__next__",
81+
[](State &s) -> ValueType {
82+
if (!s.first_or_done)
83+
++s.it;
84+
else
85+
s.first_or_done = false;
86+
87+
if (s.it == s.end) {
88+
s.first_or_done = true;
89+
throw stop_iteration();
90+
}
91+
92+
return Access()(s.it);
93+
},
94+
std::forward<Extra>(extra)...,
95+
Policy);
96+
}
9397
}
94-
9598
return borrow<typed<iterator, ValueType>>(cast(State{
9699
std::forward<Iterator>(first), std::forward<Sentinel>(last), true }));
97100
}

tests/common.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import platform
22
import gc
33
import pytest
4+
import threading
45

56
is_pypy = platform.python_implementation() == 'PyPy'
67
is_darwin = platform.system() == 'Darwin'
@@ -17,3 +18,25 @@ def collect() -> None:
1718

1819
xfail_on_pypy_darwin = pytest.mark.xfail(
1920
is_pypy and is_darwin, reason="This test for some reason fails on PyPy/Darwin")
21+
22+
23+
# Helper function to parallelize execution of a function. We intentionally
24+
# don't use the Python threads pools here to have threads shut down / start
25+
# between test cases.
26+
def parallelize(func, n_threads):
27+
barrier = threading.Barrier(n_threads)
28+
result = [None]*n_threads
29+
30+
def wrapper(i):
31+
barrier.wait()
32+
result[i] = func()
33+
34+
workers = []
35+
for i in range(n_threads):
36+
t = threading.Thread(target=wrapper, args=(i,))
37+
t.start()
38+
workers.append(t)
39+
40+
for worker in workers:
41+
worker.join()
42+
return result

tests/test_make_iterator.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import test_make_iterator_ext as t
2-
import pytest
2+
from common import parallelize
33

44
data = [
55
{},
@@ -30,6 +30,10 @@ def test03_items_iterator():
3030
assert sorted(list(m.items_l())) == sorted(list(d.items()))
3131

3232

33+
def test03_items_iterator_parallel(n_threads=8):
34+
parallelize(test03_items_iterator, n_threads=n_threads)
35+
36+
3337
def test04_passthrough_iterator():
3438
for d in data:
3539
m = t.StringMap(d)

tests/test_thread.py

Lines changed: 1 addition & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,29 +1,6 @@
11
import test_thread_ext as t
22
from test_thread_ext import Counter
3-
4-
import threading
5-
6-
# Helper function to parallelize execution of a function. We intentionally
7-
# don't use the Python threads pools here to have threads shut down / start
8-
# between test cases.
9-
def parallelize(func, n_threads):
10-
barrier = threading.Barrier(n_threads)
11-
result = [None]*n_threads
12-
13-
def wrapper(i):
14-
barrier.wait()
15-
result[i] = func()
16-
17-
workers = []
18-
for i in range(n_threads):
19-
t = threading.Thread(target=wrapper, args=(i,))
20-
t.start()
21-
workers.append(t)
22-
23-
for worker in workers:
24-
worker.join()
25-
return result
26-
3+
from common import parallelize
274

285
def test01_object_creation(n_threads=8):
296
# This test hammers 'inst_c2p' from multiple threads, and

0 commit comments

Comments
 (0)