diff --git a/src/runtime/driver/handle.rs b/src/runtime/driver/handle.rs index 115f780d..ed4a737a 100644 --- a/src/runtime/driver/handle.rs +++ b/src/runtime/driver/handle.rs @@ -34,10 +34,32 @@ pub(crate) struct WeakHandle { inner: Weak>, } +struct ThreadParker; +impl std::future::Future for ThreadParker { + type Output = (); + fn poll( + self: std::pin::Pin<&mut Self>, + ctx: &mut std::task::Context<'_>, + ) -> std::task::Poll<::Output> { + ctx.waker().clone().wake(); + std::task::Poll::Pending + } +} + impl Handle { - pub(crate) fn new(b: &crate::Builder) -> io::Result { + pub(crate) fn new( + b: &crate::Builder, + tokio_rt: &tokio::runtime::Runtime, + local: &tokio::task::LocalSet, + ) -> io::Result { + let driver = Driver::new(b)?; + let params = driver.uring.params(); + if params.is_setup_iopoll() && !params.is_setup_sqpoll() { + let _guard = tokio_rt.enter(); + local.spawn_local(ThreadParker {}); + } Ok(Self { - inner: Rc::new(RefCell::new(Driver::new(b)?)), + inner: Rc::new(RefCell::new(driver)), }) } diff --git a/src/runtime/driver/mod.rs b/src/runtime/driver/mod.rs index f57605d6..93097c73 100644 --- a/src/runtime/driver/mod.rs +++ b/src/runtime/driver/mod.rs @@ -19,7 +19,7 @@ pub(crate) struct Driver { ops: Ops, /// IoUring bindings - uring: IoUring, + pub(crate) uring: IoUring, /// Reference to the currently registered buffers. /// Ensures that the buffers are not dropped until @@ -40,6 +40,8 @@ impl Driver { pub(crate) fn new(b: &crate::Builder) -> io::Result { let uring = b.urb.build(b.entries)?; + if uring.params().is_setup_iopoll() && !uring.params().is_setup_sqpoll() {} + Ok(Driver { ops: Ops::new(), uring, diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 369c060b..19d18e83 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -81,7 +81,7 @@ impl Runtime { let tokio_rt = ManuallyDrop::new(rt); let local = ManuallyDrop::new(LocalSet::new()); - let driver = driver::Handle::new(b)?; + let driver = driver::Handle::new(b, &tokio_rt, &local)?; start_uring_wakes_task(&tokio_rt, &local, driver.clone()); diff --git a/tests/fs_file.rs b/tests/fs_file.rs index 6ec14d43..36596c33 100644 --- a/tests/fs_file.rs +++ b/tests/fs_file.rs @@ -315,6 +315,33 @@ fn basic_fallocate() { }); } +#[test] +fn iopoll_without_sqpoll() { + use std::os::unix::fs::OpenOptionsExt; + let mut builder = tokio_uring::builder(); + builder.uring_builder(&tokio_uring::uring_builder().setup_iopoll()); + let runtime = tokio_uring::Runtime::new(&builder).unwrap(); + let tmp = tempfile(); + runtime.block_on(async { + let file = std::fs::OpenOptions::new() + .write(true) + .custom_flags(libc::O_DIRECT) + .open(tmp.path()) + .unwrap(); + let file = tokio_uring::fs::File::from_std(file); + + let layout = std::alloc::Layout::from_size_align(512, 512).unwrap(); + let buf = unsafe { + let raw = std::alloc::alloc(layout); + std::ptr::copy("asdf".as_ptr(), raw, 4); + std::slice::from_raw_parts(raw, 512) + }; + + let res = file.write_at(buf, 0).submit().await.0.unwrap(); + assert_eq!(res, 512); + }); +} + fn tempfile() -> NamedTempFile { NamedTempFile::new().unwrap() }