Skip to content

Commit 24eb900

Browse files
committed
Solve FilePacket lifetime maybe longer than poller issue
1 parent 5e9aad7 commit 24eb900

File tree

6 files changed

+744
-406
lines changed

6 files changed

+744
-406
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ features = [
5353
"Win32_System_LibraryLoader",
5454
"Win32_System_Threading",
5555
"Win32_System_WindowsProgramming",
56+
"Win32_System_Pipes",
5657
]
5758

5859
[target.'cfg(target_os = "hermit")'.dependencies.hermit-abi]
@@ -71,4 +72,3 @@ signal-hook = "0.3.17"
7172

7273
[target.'cfg(windows)'.dev-dependencies]
7374
tempfile = "3.7"
74-
windows-sys = { version = "0.60", features = ["Win32_System_Pipes"] }

src/iocp/afd.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -468,13 +468,14 @@ unsafe impl<T> Completion for IoStatusBlock<T> {
468468
}
469469

470470
impl<T: FileOverlapped> FileOverlapped for IoStatusBlock<T> {
471+
#[inline]
471472
fn file_read_offset() -> usize {
472473
T::file_read_offset() + std::mem::offset_of!(IoStatusBlock<T>, data)
473474
}
474475

476+
#[inline]
475477
fn file_write_offset() -> usize {
476-
let data_offset = std::mem::offset_of!(IoStatusBlock<T>, data);
477-
T::file_write_offset() + data_offset
478+
T::file_write_offset() + std::mem::offset_of!(IoStatusBlock<T>, data)
478479
}
479480
}
480481

src/iocp/mod.rs

Lines changed: 207 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ use std::cell::UnsafeCell;
5050
use std::collections::hash_map::{Entry, HashMap};
5151
use std::ffi::c_void;
5252
use std::marker::PhantomPinned;
53-
use std::mem::{forget, MaybeUninit};
53+
use std::mem::{forget, ManuallyDrop, MaybeUninit};
5454
use std::os::windows::io::{
5555
AsHandle, AsRawHandle, AsRawSocket, BorrowedHandle, BorrowedSocket, RawHandle, RawSocket,
5656
};
@@ -94,7 +94,7 @@ pub(super) struct Poller {
9494
/// The state of the waitable handles registered with this poller.
9595
waitables: RwLock<HashMap<RawHandle, Packet>>,
9696

97-
/// The state of the waitable handles registered with this poller.
97+
/// The state of the overlapped files registered with this poller.
9898
files: RwLock<HashMap<RawHandle, Packet>>,
9999

100100
/// Sockets with pending updates.
@@ -435,16 +435,31 @@ impl Poller {
435435
}
436436

437437
/// Add a file to the poller.
438+
///
438439
/// File handle work on PollMode::Edge mode. The IOCP continue to poll the events unitl
439440
/// the file is closed. The caller must use the overlapped pointer return in IocpFilePacket
440-
/// as overlapped paramter for I/O operation. The Packet do not need to increase Arc count because
441-
/// the call can trigger events through I/O operation without update intrest events as long as the
442-
/// file handle has been registered with the IOCP. So the Packet lifetime is ended with calling [`remove_file`].
443-
/// Any I/O operation using return overlapped pointer return in IocpFilePacket is undefined behavior.
441+
/// as overlapped paramter for I/O operation. The Packet need to increase Arc count every time the I/O operation
442+
/// is performed success (return TRUE or FALSE with ERROR_IO_PENDING in last error), otherwise the Arc count do
443+
/// not need to increase if I/O operation fail because [`IocpFilePacket`] can exist after the poller is dropped.
444+
/// And I/O operation still be valid with the overlapped pointer after the poller is dropped. [`FileOverlappedConverter`]
445+
/// can help to manage the Arc count to avoid memory leak.
446+
///
447+
/// Normally, the caller use I/O helper function like [`read_file_overlapped`], [`write_file_overlapped`] or
448+
/// [`connect_named_pipe_overlapped`] to perform I/O operation to avoid the complexity of managing the Arc count.
449+
///
450+
/// [`read_file_overlapped`]: crate::os::iocp::read_file_overlapped
451+
/// [`write_file_overlapped`]: crate::os::iocp::write_file_overlapped
452+
/// [`connect_named_pipe_overlapped`]: crate::os::iocp::connect_named_pipe_overlapped
453+
///
454+
/// The call can trigger events through I/O operation without update intrest events as long as the
455+
/// file handle has been registered with the IOCP. The Packet lifetime is ended with conditions: [`remove_file`]
456+
/// is called, I/O operation is polled, and [`IocpFilePacket`] is dropped.
457+
///
458+
/// IocpFilePacket will return both read and write overlapped pointer through [`FileOverlappedConverter::as_ptr()`]
459+
/// no matter what intrest events are.
444460
///
445-
/// IocpFilePacket will return both read and write overlapped pointer no matter what intrest events are.
446-
/// The caller need to use the correct overlapped pointer for I/O operation. Such as: the read overlapped
447-
/// pointer can be used for read operations, and the write overlapped pointer can be used for write operations.
461+
/// The caller need to use the correct overlapped converter for I/O operation. Such as: the read overlapped
462+
/// converter can be used for read operations, and the write overlapped converter can be used for write operations.
448463
pub(super) fn add_file(
449464
&self,
450465
handle: RawHandle,
@@ -485,24 +500,57 @@ impl Poller {
485500
}
486501
}
487502

488-
let (read, write, file_handle) = match handle_state.as_ref().data().project_ref() {
489-
PacketInnerProj::File {
490-
read,
491-
write,
492-
handle,
493-
} => (read.get(), write.get(), handle),
494-
_ => unreachable!("PacketInner should always be File here"),
495-
};
503+
let read_ptr;
504+
let write_ptr;
505+
{
506+
let (read, write, file_handle) = match handle_state.as_ref().data().project_ref() {
507+
PacketInnerProj::File {
508+
read,
509+
write,
510+
handle,
511+
} => (read.get(), write.get(), handle),
512+
_ => unreachable!("PacketInner should always be File here"),
513+
};
496514

497-
let file_state = lock!(file_handle.lock());
498-
// Register the file handle with the I/O completion port.
499-
self.port
500-
.register(&*file_state, true, port::CompletionKeyType::File)?;
515+
let file_state = lock!(file_handle.lock());
516+
// Register the file handle with the I/O completion port.
517+
self.port
518+
.register(&*file_state, true, port::CompletionKeyType::File)?;
519+
520+
read_ptr = unsafe { (*read).as_ptr() };
521+
write_ptr = unsafe { (*write).as_ptr() };
522+
}
501523

502-
let iocp_packet = unsafe { IocpFilePacket::new((*read).as_ptr(), (*write).as_ptr()) };
524+
let iocp_packet =
525+
unsafe { IocpFilePacket::new(read_ptr, write_ptr, PacketWrapper(handle_state)) };
503526
Ok(iocp_packet)
504527
}
505528

529+
pub(super) fn modify_file(&self, handle: RawHandle, interest: Event) -> io::Result<()> {
530+
#[cfg(feature = "tracing")]
531+
tracing::trace!(
532+
"modify_file: handle={:?}, file={:p}, ev={:?}",
533+
self.port,
534+
handle,
535+
interest
536+
);
537+
538+
// Get a reference to the source.
539+
let source = {
540+
let sources = lock!(self.files.read());
541+
542+
sources
543+
.get(&handle)
544+
.cloned()
545+
.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?
546+
};
547+
548+
// Set the new event.
549+
source.as_ref().set_events(interest, PollMode::Edge);
550+
551+
Ok(())
552+
}
553+
506554
/// Remove a file from the poller.
507555
pub(super) fn remove_file(&self, handle: RawHandle) -> io::Result<()> {
508556
#[cfg(feature = "tracing")]
@@ -835,7 +883,8 @@ impl CompletionPacket {
835883
/// It needs to be pinned, since it contains data that is expected by IOCP not to be moved.
836884
type Packet = Pin<Arc<PacketUnwrapped>>;
837885
type PacketUnwrapped = IoStatusBlock<PacketInner>;
838-
/// A wrapper around the Overlapped<Packet> structure for file I/O operation result
886+
887+
/// A wrapper around the `Overlapped<Packet>` structure for file I/O operation result
839888
#[derive(Debug)]
840889
#[repr(transparent)]
841890
pub struct FileOverlappedWrapper(Overlapped<Packet>);
@@ -867,6 +916,100 @@ impl FileOverlappedWrapper {
867916
}
868917
}
869918

919+
/// The converter is used to safely reference count the Packet owned by the poller
920+
/// when overlapped I/O operation is called successfully (the operation return TRUE or ERROR_IO_PENDING).
921+
///
922+
/// If the I/O operation return FALSE with last error not ERROR_IO_PENDING, the caller must call
923+
/// [`reclaim`] to reclaim the Packet reference count. Otherwise the Packet will be leaked.
924+
///
925+
/// Normally the caller should use helper function [`read_file_overlapped`] or [`write_file_overlapped`]
926+
/// to do the I/O operation. The helper function will call `reclaim` automatically when I/O operation failed.
927+
///
928+
/// [`reclaim`]: FileOverlappedConverter::reclaim
929+
/// [`read_file_overlapped`]: crate::os::iocp::read_file_overlapped
930+
/// [`write_file_overlapped`]: crate::os::iocp::write_file_overlapped
931+
///
932+
/// # Examples
933+
/// ```rust
934+
/// use windows_sys::Win32::{Foundation as wf, Storage::FileSystem as wsf };
935+
/// fn read_file(handle: RAWHANDLE, buf: &[u8], overlapped: FileOverlappedConverter) -> io::Result<usize> {
936+
/// let mut read = 0u32;
937+
/// // Safety: syscall
938+
/// if unsafe {
939+
/// wsf::ReadFile(
940+
/// handle,
941+
/// buf.as_ptr(),
942+
/// buf.len() as u32,
943+
/// &mut read as *mut _,
944+
/// overlapped
945+
/// .as_ptr()
946+
/// .expect("The overlapped pointer may have been used for I/O operation"),
947+
/// )
948+
/// } != wf::FALSE
949+
/// {
950+
/// return Ok(read as usize);
951+
/// }
952+
///
953+
/// let err = io::Error::last_os_error();
954+
/// match err.raw_os_error().map(|e| e as u32) {
955+
/// Some(wf::ERROR_IO_PENDING) => Err(io::ErrorKind::WouldBlock.into()),
956+
/// _ => Err(err),
957+
/// }
958+
/// match err {
959+
/// Ok(size) => Ok(size),
960+
/// Err(e) if e.kind() == io::ErrorKind::WouldBlock => Err(e),
961+
/// Err(e) => {
962+
/// overlapped.reclaim(); // reclaim the Packet reference count
963+
/// Err(e)
964+
/// }
965+
/// }
966+
/// }
967+
/// ```
968+
#[derive(Debug)]
969+
pub struct FileOverlappedConverter {
970+
ptr: *mut OVERLAPPED,
971+
owner: Option<PacketWrapper>,
972+
drop: Option<ManuallyDrop<PacketWrapper>>,
973+
}
974+
975+
impl FileOverlappedConverter {
976+
pub(crate) fn new(ptr: *mut OVERLAPPED, packet: PacketWrapper) -> Self {
977+
Self {
978+
ptr,
979+
owner: Some(packet),
980+
drop: None,
981+
}
982+
}
983+
984+
/// Get the raw pointer. The caller must ensure the pointer is used for overlapped I/O operation.
985+
pub fn as_ptr(&mut self) -> Option<*mut OVERLAPPED> {
986+
if let Some(packet) = self.owner.take() {
987+
self.drop = Some(ManuallyDrop::new(packet));
988+
}
989+
Some(self.ptr)
990+
}
991+
992+
/// Reclaim the Packet reference count when I/O operation failed.
993+
pub fn reclaim(&mut self) {
994+
if let Some(drop) = self.drop.take() {
995+
self.owner = Some(ManuallyDrop::into_inner(drop));
996+
}
997+
}
998+
}
999+
1000+
#[derive(Debug, Clone)]
1001+
#[repr(transparent)]
1002+
pub(crate) struct PacketWrapper(Packet);
1003+
1004+
impl PacketWrapper {
1005+
#[doc(hidden)]
1006+
pub fn test_ref_count(&self) -> usize {
1007+
// Safety: the object is Arc and will not be moved
1008+
let inner = unsafe { &*(&self.0 as *const Packet as *const Arc<PacketUnwrapped>) };
1009+
Arc::strong_count(inner)
1010+
}
1011+
}
1012+
8701013
pin_project! {
8711014
/// The inner type of the packet.
8721015
#[project_ref = PacketInnerProj]
@@ -1022,6 +1165,13 @@ impl PacketUnwrapped {
10221165
// Update if there is no ongoing wait.
10231166
handle.status.is_idle()
10241167
}
1168+
PacketInnerProj::File { handle, .. } => {
1169+
let mut handle = lock!(handle.lock());
1170+
1171+
// Set the new interest.
1172+
handle.interest = interest;
1173+
false
1174+
}
10251175
_ => true,
10261176
}
10271177
}
@@ -1269,44 +1419,46 @@ impl PacketUnwrapped {
12691419
status: FileCompletionStatus,
12701420
bytes_transferred: u32,
12711421
) -> io::Result<FeedEventResult> {
1272-
let inner = self.as_ref().data().project_ref();
1273-
1274-
let (handle, read, write) = match inner {
1275-
PacketInnerProj::File {
1276-
handle,
1277-
read,
1278-
write,
1279-
} => (handle, read, write),
1280-
_ => unreachable!("Should not be called on a non-file packet"),
1281-
};
1422+
let return_value;
1423+
{
1424+
let inner = self.as_ref().data().project_ref();
1425+
1426+
let (handle, read, write) = match inner {
1427+
PacketInnerProj::File {
1428+
handle,
1429+
read,
1430+
write,
1431+
} => (handle, read, write),
1432+
_ => unreachable!("Should not be called on a non-file packet"),
1433+
};
12821434

1283-
let file_state = lock!(handle.lock());
1284-
let mut event = Event::none(file_state.interest.key);
1285-
if status.is_read() {
1286-
unsafe {
1287-
(*read.get()).set_bytes_transferred(bytes_transferred);
1435+
let file_state = lock!(handle.lock());
1436+
let mut event = Event::none(file_state.interest.key);
1437+
if status.is_read() {
1438+
unsafe {
1439+
(*read.get()).set_bytes_transferred(bytes_transferred);
1440+
}
1441+
event.readable = true;
12881442
}
1289-
event.readable = true;
1290-
}
12911443

1292-
if status.is_write() {
1293-
unsafe {
1294-
(*write.get()).set_bytes_transferred(bytes_transferred);
1444+
if status.is_write() {
1445+
unsafe {
1446+
(*write.get()).set_bytes_transferred(bytes_transferred);
1447+
}
1448+
event.writable = true;
12951449
}
1296-
event.writable = true;
1297-
}
12981450

1299-
event.readable &= file_state.interest.readable;
1300-
event.writable &= file_state.interest.writable;
1301-
1302-
// If this event doesn't have anything that interests us, don't return or
1303-
// update the oneshot state.
1304-
let return_value = if event.readable || event.writable {
1305-
FeedEventResult::Event(event)
1306-
} else {
1307-
FeedEventResult::NoEvent
1308-
};
1451+
event.readable &= file_state.interest.readable;
1452+
event.writable &= file_state.interest.writable;
13091453

1454+
// If this event doesn't have anything that interests us, don't return or
1455+
// update the oneshot state.
1456+
return_value = if event.readable || event.writable {
1457+
FeedEventResult::Event(event)
1458+
} else {
1459+
FeedEventResult::NoEvent
1460+
};
1461+
}
13101462
Ok(return_value)
13111463
}
13121464

0 commit comments

Comments
 (0)