Skip to content

Commit ddb64e8

Browse files
authored
Python node event mutex (dora-rs#1244)
Allows concurrent read/write access through a node (to address dora-rs#748). "Already borrowed" seems to happen when there are concurrent calls to read/write functions because both currently require an exclusive `Node` ref. 1. Updates `DelayedCleanup.get_mut` and callers to use `&self` instead of `&mut self`. 2. Adds a mutex to `Events.inner` so callers can use `&self` instead of `&mut self`. 3. Adds an example dataflow that fails on main (with the "Already borrowed" error) but works with this change
2 parents e14d802 + 8965160 commit ddb64e8

File tree

5 files changed

+102
-25
lines changed

5 files changed

+102
-25
lines changed

apis/python/node/src/lib.rs

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::env::current_dir;
44
use std::path::PathBuf;
55
use std::sync::Arc;
66
use std::time::Duration;
7+
use tokio::sync::Mutex;
78

89
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
910
use dora_download::download_file;
@@ -135,7 +136,7 @@ impl Node {
135136

136137
Ok(Node {
137138
events: Events {
138-
inner: EventsInner::Dora(events),
139+
inner: Arc::new(Mutex::new(EventsInner::Dora(events))),
139140
_cleanup_handle: cleanup_handle,
140141
},
141142
dataflow_id,
@@ -167,7 +168,7 @@ impl Node {
167168
/// :rtype: dict
168169
#[pyo3(signature = (timeout=None))]
169170
#[allow(clippy::should_implement_trait)]
170-
pub fn next(&mut self, py: Python, timeout: Option<f32>) -> PyResult<Option<Py<PyDict>>> {
171+
pub fn next(&self, py: Python, timeout: Option<f32>) -> PyResult<Option<Py<PyDict>>> {
171172
let event = py.allow_threads(|| self.events.recv(timeout.map(Duration::from_secs_f32)));
172173
if let Some(event) = event {
173174
let dict = event
@@ -190,7 +191,7 @@ impl Node {
190191
///
191192
/// :rtype: list[dict]
192193
#[allow(clippy::should_implement_trait)]
193-
pub fn drain(&mut self, py: Python) -> PyResult<Vec<Py<PyDict>>> {
194+
pub fn drain(&self, py: Python) -> PyResult<Vec<Py<PyDict>>> {
194195
let events = self
195196
.events
196197
.drain()
@@ -253,7 +254,7 @@ impl Node {
253254
/// :rtype: dict
254255
#[pyo3(signature = (timeout=None))]
255256
#[allow(clippy::should_implement_trait)]
256-
pub async fn recv_async(&mut self, timeout: Option<f32>) -> PyResult<Option<Py<PyDict>>> {
257+
pub async fn recv_async(&self, timeout: Option<f32>) -> PyResult<Option<Py<PyDict>>> {
257258
let event = self
258259
.events
259260
.recv_async_timeout(timeout.map(Duration::from_secs_f32))
@@ -264,6 +265,7 @@ impl Node {
264265
let dict = event
265266
.to_py_dict(py)
266267
.context("Could not convert event into a dict")?;
268+
267269
Ok(Some(dict))
268270
})
269271
} else {
@@ -284,7 +286,7 @@ impl Node {
284286
/// Default behaviour is to timeout after 2 seconds.
285287
///
286288
/// :rtype: dict
287-
pub fn __next__(&mut self, py: Python) -> PyResult<Option<Py<PyDict>>> {
289+
pub fn __next__(&self, py: Python) -> PyResult<Option<Py<PyDict>>> {
288290
self.next(py, None)
289291
}
290292

@@ -324,7 +326,7 @@ impl Node {
324326
/// :rtype: None
325327
#[pyo3(signature = (output_id, data, metadata=None))]
326328
pub fn send_output(
327-
&mut self,
329+
&self,
328330
output_id: String,
329331
data: PyObject,
330332
metadata: Option<Bound<'_, PyDict>>,
@@ -356,7 +358,7 @@ impl Node {
356358
/// This method returns the parsed dataflow YAML file.
357359
///
358360
/// :rtype: dict
359-
pub fn dataflow_descriptor(&mut self, py: Python) -> eyre::Result<PyObject> {
361+
pub fn dataflow_descriptor(&self, py: Python) -> eyre::Result<PyObject> {
360362
Ok(
361363
pythonize::pythonize(py, &self.node.get_mut().dataflow_descriptor()?)
362364
.map(|x| x.unbind())?,
@@ -366,7 +368,7 @@ impl Node {
366368
/// Returns the node configuration.
367369
///
368370
/// :rtype: dict
369-
pub fn node_config(&mut self, py: Python) -> eyre::Result<PyObject> {
371+
pub fn node_config(&self, py: Python) -> eyre::Result<PyObject> {
370372
Ok(pythonize::pythonize(py, &self.node.get_mut().node_config()).map(|x| x.unbind())?)
371373
}
372374

@@ -382,10 +384,7 @@ impl Node {
382384
///
383385
/// :type subscription: dora.Ros2Subscription
384386
/// :rtype: None
385-
pub fn merge_external_events(
386-
&mut self,
387-
subscription: &mut Ros2Subscription,
388-
) -> eyre::Result<()> {
387+
pub fn merge_external_events(&self, subscription: &mut Ros2Subscription) -> eyre::Result<()> {
389388
let subscription = subscription.into_stream()?;
390389
let stream = futures::stream::poll_fn(move |cx| {
391390
let s = subscription.as_stream().map(|item| {
@@ -404,12 +403,13 @@ impl Node {
404403
});
405404

406405
// take out the event stream and temporarily replace it with a dummy
406+
let mut inner = self.events.inner.blocking_lock();
407407
let events = std::mem::replace(
408-
&mut self.events.inner,
408+
&mut *inner,
409409
EventsInner::Merged(Box::new(futures::stream::empty())),
410410
);
411411
// update self.events with the merged stream
412-
self.events.inner = EventsInner::Merged(events.merge_external_send(Box::pin(stream)));
412+
*inner = EventsInner::Merged(events.merge_external_send(Box::pin(stream)));
413413

414414
Ok(())
415415
}
@@ -424,13 +424,14 @@ fn err_to_pyany(err: eyre::Report, gil: Python<'_>) -> Py<PyAny> {
424424
}
425425

426426
struct Events {
427-
inner: EventsInner,
427+
inner: Arc<Mutex<EventsInner>>,
428428
_cleanup_handle: NodeCleanupHandle,
429429
}
430430

431431
impl Events {
432-
fn recv(&mut self, timeout: Option<Duration>) -> Option<PyEvent> {
433-
let event = match &mut self.inner {
432+
fn recv(&self, timeout: Option<Duration>) -> Option<PyEvent> {
433+
let mut inner = self.inner.blocking_lock();
434+
let event = match &mut *inner {
434435
EventsInner::Dora(events) => match timeout {
435436
Some(timeout) => events.recv_timeout(timeout).map(MergedEvent::Dora),
436437
None => events.recv().map(MergedEvent::Dora),
@@ -440,8 +441,9 @@ impl Events {
440441
event.map(|event| PyEvent { event })
441442
}
442443

443-
fn try_recv(&mut self) -> Result<PyEvent, TryRecvError> {
444-
let event = match &mut self.inner {
444+
fn try_recv(&self) -> Result<PyEvent, TryRecvError> {
445+
let mut inner = self.inner.blocking_lock();
446+
let event = match &mut *inner {
445447
EventsInner::Dora(events) => events.try_recv().map(MergedEvent::Dora),
446448
EventsInner::Merged(_events) => {
447449
todo!("try_recv on external event stream is not yet implemented!")
@@ -450,8 +452,9 @@ impl Events {
450452
event.map(|event| PyEvent { event })
451453
}
452454

453-
async fn recv_async_timeout(&mut self, timeout: Option<Duration>) -> Option<PyEvent> {
454-
let event = match &mut self.inner {
455+
async fn recv_async_timeout(&self, timeout: Option<Duration>) -> Option<PyEvent> {
456+
let mut inner = self.inner.lock().await;
457+
let event = match &mut *inner {
455458
EventsInner::Dora(events) => match timeout {
456459
Some(timeout) => events
457460
.recv_async_timeout(timeout)
@@ -464,8 +467,9 @@ impl Events {
464467
event.map(|event| PyEvent { event })
465468
}
466469

467-
fn drain(&mut self) -> Option<Vec<PyEvent>> {
468-
match &mut self.inner {
470+
fn drain(&self) -> Option<Vec<PyEvent>> {
471+
let mut inner = self.inner.blocking_lock();
472+
match &mut *inner {
469473
EventsInner::Dora(events) => match events.drain() {
470474
Some(items) => {
471475
return Some(
@@ -485,7 +489,8 @@ impl Events {
485489
}
486490

487491
fn is_empty(&self) -> bool {
488-
match &self.inner {
492+
let inner = self.inner.blocking_lock();
493+
match &*inner {
489494
EventsInner::Dora(events) => events.is_empty(),
490495
EventsInner::Merged(_events) => {
491496
todo!("is_empty on external event stream is not yet implemented!")

apis/python/operator/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ impl<T> DelayedCleanup<T> {
4040
CleanupHandle(self.0.clone())
4141
}
4242

43-
pub fn get_mut(&mut self) -> std::sync::MutexGuard<T> {
43+
pub fn get_mut(&self) -> std::sync::MutexGuard<T> {
4444
self.0.try_lock().expect("failed to lock DelayedCleanup")
4545
}
4646
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
nodes:
2+
- id: node1
3+
path: ./receive_data.py
4+
inputs:
5+
data: node2/data
6+
outputs:
7+
- data
8+
9+
- id: node2
10+
path: ./receive_data.py
11+
inputs:
12+
data: node1/data
13+
outputs:
14+
- data
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[project]
2+
name = "dora-test"
3+
version = "0.1.0"
4+
description = "Add your description here"
5+
readme = "README.md"
6+
requires-python = ">=3.12"
7+
dependencies = [
8+
"asyncio>=4.0.0",
9+
"dora-rs-cli>=0.3.13",
10+
"numpy>=2.3.5",
11+
]
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from dora import Node
2+
3+
import logging
4+
import threading
5+
import time
6+
7+
import numpy as np
8+
import pyarrow as pa
9+
10+
11+
def read_data_task(node, log):
12+
"""Task that reads incoming events."""
13+
while (event := node.next()) is not None:
14+
if event["type"] == "INPUT":
15+
print(f"info {event['value'].to_numpy()}")
16+
del event
17+
log.log(logging.INFO, "read_data_task done!")
18+
19+
20+
def publish_task(node, log):
21+
"""Task that publishes to a topic."""
22+
while True:
23+
time.sleep(1) # Publish every 1s
24+
now = time.perf_counter_ns()
25+
node.send_output("data", pa.array([np.uint64(now)]))
26+
27+
28+
def main():
29+
node = Node()
30+
log = logging.getLogger(__name__)
31+
32+
# Create thread for read task
33+
read_thread = threading.Thread(target=read_data_task, args=(node, log))
34+
read_thread.start()
35+
36+
# Run publish task in a daemon thread (so it doesn't block main thread)
37+
publish_thread = threading.Thread(target=publish_task, args=(node, log), daemon=True)
38+
publish_thread.start()
39+
40+
# Wait for read thread to complete
41+
read_thread.join()
42+
43+
log.log(logging.INFO, "done!")
44+
45+
46+
if __name__ == "__main__":
47+
main()

0 commit comments

Comments
 (0)