Narrow the nb_inst flag bit-fields from uint32_t to uint8_t and move
'clear_keep_alive' out of the bit-field into its own byte, because it may be
accessed concurrently with 'state'. Adds a threaded shared_ptr regression test
(test08_shared_ptr_threaded_access) together with its C++ bindings.

diff --git a/src/nb_internals.h b/src/nb_internals.h
index f919499e..4a1fbb40 100644
--- a/src/nb_internals.h
+++ b/src/nb_internals.h
@@ -56,7 +56,7 @@ struct nb_inst { // usually: 24 bytes
 
     /// State of the C++ object this instance points to: is it constructed?
     /// can we use it?
-    uint32_t state : 2;
+    uint8_t state : 2;
 
     // Values for `state`. Note that the numeric values of these are relied upon
     // for an optimization in `nb_type_get()`.
@@ -70,25 +70,27 @@ struct nb_inst { // usually: 24 bytes
      * relative offset to a pointer that must be dereferenced to get to the
      * instance data. 'direct' is 'true' in the former case.
      */
-    uint32_t direct : 1;
+    uint8_t direct : 1;
 
     /// Is the instance data co-located with the Python object?
-    uint32_t internal : 1;
+    uint8_t internal : 1;
 
     /// Should the destructor be called when this instance is GCed?
-    uint32_t destruct : 1;
+    uint8_t destruct : 1;
 
     /// Should nanobind call 'operator delete' when this instance is GCed?
-    uint32_t cpp_delete : 1;
-
-    /// Does this instance hold references to others? (via internals.keep_alive)
-    uint32_t clear_keep_alive : 1;
+    uint8_t cpp_delete : 1;
 
     /// Does this instance use intrusive reference counting?
-    uint32_t intrusive : 1;
+    uint8_t intrusive : 1;
+
+    /// Does this instance hold references to others? (via internals.keep_alive)
+    /// This may be accessed concurrently with 'state', so it must not be in
+    /// the same bitfield as 'state'.
+    uint8_t clear_keep_alive;
 
     // That's a lot of unused space. I wonder if there is a good use for it..
-    uint32_t unused : 24;
+    uint16_t unused;
 };
 
 static_assert(sizeof(nb_inst) == sizeof(PyObject) + sizeof(uint32_t) * 2);
diff --git a/tests/test_thread.cpp b/tests/test_thread.cpp
index 97e82960..34f6ab98 100644
--- a/tests/test_thread.cpp
+++ b/tests/test_thread.cpp
@@ -1,4 +1,8 @@
 #include <nanobind/nanobind.h>
+#include <nanobind/stl/shared_ptr.h>
+
+#include <memory>
+#include <vector>
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -32,6 +36,11 @@ class ClassWithClassProperty {
     ClassWithProperty value_;
 };
 
+struct AnInt {
+    int value;
+    AnInt(int v) : value(v) {}
+};
+
 
 NB_MODULE(test_thread_ext, m) {
     nb::class_<Counter>(m, "Counter")
@@ -68,4 +77,17 @@ NB_MODULE(test_thread_ext, m) {
             new (self) ClassWithClassProperty(std::move(value));
         }, nb::arg("value"))
         .def_prop_ro("prop1", &ClassWithClassProperty::get_prop);
+
+    nb::class_<AnInt>(m, "AnInt")
+        .def(nb::init<int>())
+        .def_rw("value", &AnInt::value);
+
+    std::vector<std::shared_ptr<AnInt>> shared_ints;
+    for (int i = 0; i < 5; ++i) {
+        shared_ints.push_back(std::make_shared<AnInt>(i));
+    }
+    m.def("fetch_shared_int", [shared_ints](int i) {
+        return shared_ints.at(i);
+    });
+    m.def("consume_an_int", [](AnInt* p) { return p->value; });
 }
diff --git a/tests/test_thread.py b/tests/test_thread.py
index 1d5992ee..22bb5c71 100644
--- a/tests/test_thread.py
+++ b/tests/test_thread.py
@@ -1,3 +1,6 @@
+import random
+import threading
+
 import test_thread_ext as t
 from test_thread_ext import Counter, GlobalData, ClassWithProperty, ClassWithClassProperty
 from common import parallelize
@@ -100,3 +103,16 @@ def f():
         _ = c2.prop1.prop2
 
     parallelize(f, n_threads=n_threads)
+
+
+def test08_shared_ptr_threaded_access(n_threads=8):
+    # Test for keep_alive racing with other fields.
+    def f(barrier):
+        i = random.randint(0, 4)
+        barrier.wait()
+        p = t.fetch_shared_int(i)
+        assert t.consume_an_int(p) == i
+
+    for _ in range(100):
+        barrier = threading.Barrier(n_threads)
+        parallelize(lambda: f(barrier), n_threads=n_threads)