Skip to content

Commit f7d8dad

Browse files
committed
Solve FilePacket lifetime maybe longer than poller issue
1 parent 5e9aad7 commit f7d8dad

File tree

6 files changed

+891
-538
lines changed

6 files changed

+891
-538
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ features = [
5353
"Win32_System_LibraryLoader",
5454
"Win32_System_Threading",
5555
"Win32_System_WindowsProgramming",
56+
"Win32_System_Pipes",
5657
]
5758

5859
[target.'cfg(target_os = "hermit")'.dependencies.hermit-abi]
@@ -71,4 +72,3 @@ signal-hook = "0.3.17"
7172

7273
[target.'cfg(windows)'.dev-dependencies]
7374
tempfile = "3.7"
74-
windows-sys = { version = "0.60", features = ["Win32_System_Pipes"] }

src/iocp/afd.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -468,13 +468,14 @@ unsafe impl<T> Completion for IoStatusBlock<T> {
468468
}
469469

470470
impl<T: FileOverlapped> FileOverlapped for IoStatusBlock<T> {
471+
#[inline]
471472
fn file_read_offset() -> usize {
472473
T::file_read_offset() + std::mem::offset_of!(IoStatusBlock<T>, data)
473474
}
474475

476+
#[inline]
475477
fn file_write_offset() -> usize {
476-
let data_offset = std::mem::offset_of!(IoStatusBlock<T>, data);
477-
T::file_write_offset() + data_offset
478+
T::file_write_offset() + std::mem::offset_of!(IoStatusBlock<T>, data)
478479
}
479480
}
480481

src/iocp/mod.rs

Lines changed: 217 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ use std::cell::UnsafeCell;
5050
use std::collections::hash_map::{Entry, HashMap};
5151
use std::ffi::c_void;
5252
use std::marker::PhantomPinned;
53-
use std::mem::{forget, MaybeUninit};
53+
use std::mem::{forget, ManuallyDrop, MaybeUninit};
5454
use std::os::windows::io::{
5555
AsHandle, AsRawHandle, AsRawSocket, BorrowedHandle, BorrowedSocket, RawHandle, RawSocket,
5656
};
@@ -94,7 +94,7 @@ pub(super) struct Poller {
9494
/// The state of the waitable handles registered with this poller.
9595
waitables: RwLock<HashMap<RawHandle, Packet>>,
9696

97-
/// The state of the waitable handles registered with this poller.
97+
/// The state of the overlapped files registered with this poller.
9898
files: RwLock<HashMap<RawHandle, Packet>>,
9999

100100
/// Sockets with pending updates.
@@ -435,16 +435,31 @@ impl Poller {
435435
}
436436

437437
/// Add a file to the poller.
438+
///
438439
/// File handle work on PollMode::Edge mode. The IOCP continue to poll the events unitl
439440
/// the file is closed. The caller must use the overlapped pointer return in IocpFilePacket
440-
/// as overlapped paramter for I/O operation. The Packet do not need to increase Arc count because
441-
/// the call can trigger events through I/O operation without update intrest events as long as the
442-
/// file handle has been registered with the IOCP. So the Packet lifetime is ended with calling [`remove_file`].
443-
/// Any I/O operation using return overlapped pointer return in IocpFilePacket is undefined behavior.
441+
/// as overlapped paramter for I/O operation. The Packet need to increase Arc count every time the I/O operation
442+
/// is performed success (return TRUE or FALSE with ERROR_IO_PENDING in last error), otherwise the Arc count do
443+
/// not need to increase if I/O operation fail because [`IocpFilePacket`] can exist after the poller is dropped.
444+
/// And I/O operation still be valid with the overlapped pointer after the poller is dropped. [`FileOverlappedConverter`]
445+
/// can help to manage the Arc count to avoid memory leak.
446+
///
447+
/// Normally, the caller use I/O helper function like [`read_file_overlapped`], [`write_file_overlapped`] or
448+
/// [`connect_named_pipe_overlapped`] to perform I/O operation to avoid the complexity of managing the Arc count.
449+
///
450+
/// [`read_file_overlapped`]: crate::os::iocp::read_file_overlapped
451+
/// [`write_file_overlapped`]: crate::os::iocp::write_file_overlapped
452+
/// [`connect_named_pipe_overlapped`]: crate::os::iocp::connect_named_pipe_overlapped
453+
///
454+
/// The call can trigger events through I/O operation without update intrest events as long as the
455+
/// file handle has been registered with the IOCP. The Packet lifetime is ended with conditions: [`remove_file`]
456+
/// is called, I/O operation is polled, and [`IocpFilePacket`] is dropped.
457+
///
458+
/// IocpFilePacket will return both read and write overlapped pointer through [`FileOverlappedConverter::as_ptr()`]
459+
/// no matter what intrest events are.
444460
///
445-
/// IocpFilePacket will return both read and write overlapped pointer no matter what intrest events are.
446-
/// The caller need to use the correct overlapped pointer for I/O operation. Such as: the read overlapped
447-
/// pointer can be used for read operations, and the write overlapped pointer can be used for write operations.
461+
/// The caller need to use the correct overlapped converter for I/O operation. Such as: the read overlapped
462+
/// converter can be used for read operations, and the write overlapped converter can be used for write operations.
448463
pub(super) fn add_file(
449464
&self,
450465
handle: RawHandle,
@@ -485,24 +500,57 @@ impl Poller {
485500
}
486501
}
487502

488-
let (read, write, file_handle) = match handle_state.as_ref().data().project_ref() {
489-
PacketInnerProj::File {
490-
read,
491-
write,
492-
handle,
493-
} => (read.get(), write.get(), handle),
494-
_ => unreachable!("PacketInner should always be File here"),
495-
};
503+
let read_ptr;
504+
let write_ptr;
505+
{
506+
let (read, write, file_handle) = match handle_state.as_ref().data().project_ref() {
507+
PacketInnerProj::File {
508+
read,
509+
write,
510+
handle,
511+
} => (read.get(), write.get(), handle),
512+
_ => unreachable!("PacketInner should always be File here"),
513+
};
496514

497-
let file_state = lock!(file_handle.lock());
498-
// Register the file handle with the I/O completion port.
499-
self.port
500-
.register(&*file_state, true, port::CompletionKeyType::File)?;
515+
let file_state = lock!(file_handle.lock());
516+
// Register the file handle with the I/O completion port.
517+
self.port
518+
.register(&*file_state, true, port::CompletionKeyType::File)?;
519+
520+
read_ptr = unsafe { (*read).as_ptr() };
521+
write_ptr = unsafe { (*write).as_ptr() };
522+
}
501523

502-
let iocp_packet = unsafe { IocpFilePacket::new((*read).as_ptr(), (*write).as_ptr()) };
524+
let iocp_packet =
525+
unsafe { IocpFilePacket::new(read_ptr, write_ptr, PacketWrapper(handle_state)) };
503526
Ok(iocp_packet)
504527
}
505528

529+
pub(super) fn modify_file(&self, handle: RawHandle, interest: Event) -> io::Result<()> {
530+
#[cfg(feature = "tracing")]
531+
tracing::trace!(
532+
"modify_file: handle={:?}, file={:p}, ev={:?}",
533+
self.port,
534+
handle,
535+
interest
536+
);
537+
538+
// Get a reference to the source.
539+
let source = {
540+
let sources = lock!(self.files.read());
541+
542+
sources
543+
.get(&handle)
544+
.cloned()
545+
.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?
546+
};
547+
548+
// Set the new event.
549+
source.as_ref().set_events(interest, PollMode::Edge);
550+
551+
Ok(())
552+
}
553+
506554
/// Remove a file from the poller.
507555
pub(super) fn remove_file(&self, handle: RawHandle) -> io::Result<()> {
508556
#[cfg(feature = "tracing")]
@@ -835,7 +883,8 @@ impl CompletionPacket {
835883
/// It needs to be pinned, since it contains data that is expected by IOCP not to be moved.
836884
type Packet = Pin<Arc<PacketUnwrapped>>;
837885
type PacketUnwrapped = IoStatusBlock<PacketInner>;
838-
/// A wrapper around the Overlapped<Packet> structure for file I/O operation result
886+
887+
/// A wrapper around the `Overlapped<Packet>` structure for file I/O operation result
839888
#[derive(Debug)]
840889
#[repr(transparent)]
841890
pub struct FileOverlappedWrapper(Overlapped<Packet>);
@@ -867,6 +916,110 @@ impl FileOverlappedWrapper {
867916
}
868917
}
869918

919+
/// The converter is used to safely reference count the Packet owned by the poller
920+
/// when overlapped I/O operation is called successfully (the operation return TRUE or ERROR_IO_PENDING).
921+
///
922+
/// If the I/O operation return FALSE with last error not ERROR_IO_PENDING, the caller must call
923+
/// [`reclaim`] to reclaim the Packet reference count. Otherwise the Packet will be leaked.
924+
///
925+
/// Normally the caller should use helper function [`read_file_overlapped`] or [`write_file_overlapped`]
926+
/// to do the I/O operation. The helper function will call `reclaim` automatically when I/O operation failed.
927+
///
928+
/// [`reclaim`]: FileOverlappedConverter::reclaim
929+
/// [`read_file_overlapped`]: crate::os::iocp::read_file_overlapped
930+
/// [`write_file_overlapped`]: crate::os::iocp::write_file_overlapped
931+
///
932+
/// # Examples
933+
///
934+
/// ```no_run
935+
/// use polling::os::iocp::FileOverlappedConverter;
936+
/// use std::{io, os::windows::io::RawHandle};
937+
/// use windows_sys::Win32::{Foundation as wf, Storage::FileSystem as wsf};
938+
/// fn read_file(
939+
/// handle: RawHandle,
940+
/// buf: &mut [u8],
941+
/// mut overlapped: FileOverlappedConverter,
942+
/// ) -> io::Result<usize> {
943+
/// let mut read = 0u32;
944+
/// // Safety: syscall
945+
/// if unsafe {
946+
/// wsf::ReadFile(
947+
/// handle,
948+
/// buf.as_mut_ptr(),
949+
/// buf.len() as u32,
950+
/// &mut read as *mut _,
951+
/// overlapped
952+
/// .as_ptr()
953+
/// .expect("The overlapped pointer may have been used for I/O operation"),
954+
/// )
955+
/// } != wf::FALSE
956+
/// {
957+
/// return Ok(read as usize);
958+
/// }
959+
///
960+
/// let err = io::Error::last_os_error();
961+
/// let err: io::Result<usize> = err
962+
/// .raw_os_error()
963+
/// .map(|e| match (e as u32) {
964+
/// wf::ERROR_IO_PENDING => Err(io::ErrorKind::WouldBlock.into()),
965+
/// _ => Err(err),
966+
/// })
967+
/// .unwrap();
968+
/// match err {
969+
/// Err(e) if e.kind() == io::ErrorKind::WouldBlock => Err(e),
970+
/// Err(e) => {
971+
/// overlapped.reclaim(); // reclaim the Packet reference count
972+
/// Err(e)
973+
/// }
974+
/// _ => unreachable!(),
975+
/// }
976+
/// }
977+
/// ```
978+
#[derive(Debug)]
979+
pub struct FileOverlappedConverter {
980+
ptr: *mut OVERLAPPED,
981+
owner: Option<PacketWrapper>,
982+
drop: Option<ManuallyDrop<PacketWrapper>>,
983+
}
984+
985+
impl FileOverlappedConverter {
986+
pub(crate) fn new(ptr: *mut OVERLAPPED, packet: PacketWrapper) -> Self {
987+
Self {
988+
ptr,
989+
owner: Some(packet),
990+
drop: None,
991+
}
992+
}
993+
994+
/// Get the raw pointer. The caller must ensure the pointer is used for overlapped I/O operation.
995+
pub fn as_ptr(&mut self) -> Option<*mut OVERLAPPED> {
996+
if let Some(packet) = self.owner.take() {
997+
self.drop = Some(ManuallyDrop::new(packet));
998+
}
999+
Some(self.ptr)
1000+
}
1001+
1002+
/// Reclaim the Packet reference count when I/O operation failed.
1003+
pub fn reclaim(&mut self) {
1004+
if let Some(drop) = self.drop.take() {
1005+
self.owner = Some(ManuallyDrop::into_inner(drop));
1006+
}
1007+
}
1008+
}
1009+
1010+
#[derive(Debug, Clone)]
1011+
#[repr(transparent)]
1012+
pub(crate) struct PacketWrapper(Packet);
1013+
1014+
impl PacketWrapper {
1015+
#[doc(hidden)]
1016+
pub fn test_ref_count(&self) -> usize {
1017+
// Safety: the object is Arc and will not be moved
1018+
let inner = unsafe { &*(&self.0 as *const Packet as *const Arc<PacketUnwrapped>) };
1019+
Arc::strong_count(inner)
1020+
}
1021+
}
1022+
8701023
pin_project! {
8711024
/// The inner type of the packet.
8721025
#[project_ref = PacketInnerProj]
@@ -1022,6 +1175,13 @@ impl PacketUnwrapped {
10221175
// Update if there is no ongoing wait.
10231176
handle.status.is_idle()
10241177
}
1178+
PacketInnerProj::File { handle, .. } => {
1179+
let mut handle = lock!(handle.lock());
1180+
1181+
// Set the new interest.
1182+
handle.interest = interest;
1183+
false
1184+
}
10251185
_ => true,
10261186
}
10271187
}
@@ -1269,44 +1429,46 @@ impl PacketUnwrapped {
12691429
status: FileCompletionStatus,
12701430
bytes_transferred: u32,
12711431
) -> io::Result<FeedEventResult> {
1272-
let inner = self.as_ref().data().project_ref();
1273-
1274-
let (handle, read, write) = match inner {
1275-
PacketInnerProj::File {
1276-
handle,
1277-
read,
1278-
write,
1279-
} => (handle, read, write),
1280-
_ => unreachable!("Should not be called on a non-file packet"),
1281-
};
1432+
let return_value;
1433+
{
1434+
let inner = self.as_ref().data().project_ref();
1435+
1436+
let (handle, read, write) = match inner {
1437+
PacketInnerProj::File {
1438+
handle,
1439+
read,
1440+
write,
1441+
} => (handle, read, write),
1442+
_ => unreachable!("Should not be called on a non-file packet"),
1443+
};
12821444

1283-
let file_state = lock!(handle.lock());
1284-
let mut event = Event::none(file_state.interest.key);
1285-
if status.is_read() {
1286-
unsafe {
1287-
(*read.get()).set_bytes_transferred(bytes_transferred);
1445+
let file_state = lock!(handle.lock());
1446+
let mut event = Event::none(file_state.interest.key);
1447+
if status.is_read() {
1448+
unsafe {
1449+
(*read.get()).set_bytes_transferred(bytes_transferred);
1450+
}
1451+
event.readable = true;
12881452
}
1289-
event.readable = true;
1290-
}
12911453

1292-
if status.is_write() {
1293-
unsafe {
1294-
(*write.get()).set_bytes_transferred(bytes_transferred);
1454+
if status.is_write() {
1455+
unsafe {
1456+
(*write.get()).set_bytes_transferred(bytes_transferred);
1457+
}
1458+
event.writable = true;
12951459
}
1296-
event.writable = true;
1297-
}
12981460

1299-
event.readable &= file_state.interest.readable;
1300-
event.writable &= file_state.interest.writable;
1301-
1302-
// If this event doesn't have anything that interests us, don't return or
1303-
// update the oneshot state.
1304-
let return_value = if event.readable || event.writable {
1305-
FeedEventResult::Event(event)
1306-
} else {
1307-
FeedEventResult::NoEvent
1308-
};
1461+
event.readable &= file_state.interest.readable;
1462+
event.writable &= file_state.interest.writable;
13091463

1464+
// If this event doesn't have anything that interests us, don't return or
1465+
// update the oneshot state.
1466+
return_value = if event.readable || event.writable {
1467+
FeedEventResult::Event(event)
1468+
} else {
1469+
FeedEventResult::NoEvent
1470+
};
1471+
}
13101472
Ok(return_value)
13111473
}
13121474

0 commit comments

Comments
 (0)