diff --git a/crates/executor/src/lib.rs b/crates/executor/src/lib.rs index f04dbfa..1fca0db 100644 --- a/crates/executor/src/lib.rs +++ b/crates/executor/src/lib.rs @@ -1,6 +1,7 @@ use bindings::wasi::io; use std::future::Future; use std::mem; +use std::ops::DerefMut; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll, Wake, Waker}; @@ -20,11 +21,42 @@ impl std::fmt::Display for io::streams::Error { impl std::error::Error for io::streams::Error {} -static WAKERS: Mutex> = Mutex::new(Vec::new()); +type Wrapped = Arc>>; + +static WAKERS: Mutex> = Mutex::new(Vec::new()); + +/// Handle to a Pollable pushed using `push_waker` which may be used to cancel +/// and drop the Pollable. +pub struct CancelToken(Wrapped); + +impl CancelToken { + /// Cancel and drop the Pollable. + pub fn cancel(self) { + drop(self.0.lock().unwrap().take()) + } +} + +/// Handle to a Pollable pushed using `push_waker` which, when dropped, will +/// cancel and drop the Pollable. +pub struct CancelOnDropToken(Wrapped); + +impl From for CancelOnDropToken { + fn from(token: CancelToken) -> Self { + Self(token.0) + } +} + +impl Drop for CancelOnDropToken { + fn drop(&mut self) { + drop(self.0.lock().unwrap().take()) + } +} /// Push a Pollable and Waker to WAKERS. -pub fn push_waker(pollable: io::poll::Pollable, waker: Waker) { - WAKERS.lock().unwrap().push((pollable, waker)); +pub fn push_waker(pollable: io::poll::Pollable, waker: Waker) -> CancelToken { + let wrapped = Arc::new(Mutex::new(Some(pollable))); + WAKERS.lock().unwrap().push((wrapped.clone(), waker)); + CancelToken(wrapped) } /// Run the specified future to completion blocking until it yields a result. @@ -45,13 +77,17 @@ pub fn run(future: impl Future) -> T { Poll::Pending => { let mut new_wakers = Vec::new(); - let wakers = mem::take::>(&mut WAKERS.lock().unwrap()); - - assert!(!wakers.is_empty()); + let wakers = mem::take(WAKERS.lock().unwrap().deref_mut()) + .into_iter() + .filter_map(|(wrapped, waker)| { + let pollable = wrapped.lock().unwrap().take(); + pollable.map(|pollable| (wrapped, pollable, waker)) + }) + .collect::>(); let pollables = wakers .iter() - .map(|(pollable, _)| pollable) + .map(|(_, pollable, _)| pollable) .collect::>(); let mut ready = vec![false; wakers.len()]; @@ -60,11 +96,12 @@ pub fn run(future: impl Future) -> T { ready[usize::try_from(index).unwrap()] = true; } - for (ready, (pollable, waker)) in ready.into_iter().zip(wakers) { + for (ready, (wrapped, pollable, waker)) in ready.into_iter().zip(wakers) { if ready { waker.wake() } else { - new_wakers.push((pollable, waker)); + *wrapped.lock().unwrap() = Some(pollable); + new_wakers.push((wrapped, waker)); } } diff --git a/src/http/executor.rs b/src/http/executor.rs index c661c83..a985adb 100644 --- a/src/http/executor.rs +++ b/src/http/executor.rs @@ -1,6 +1,7 @@ use crate::wit::wasi::http0_2_0::outgoing_handler; use crate::wit::wasi::http0_2_0::types::{ - ErrorCode, IncomingBody, IncomingResponse, OutgoingBody, OutgoingRequest, + ErrorCode, FutureIncomingResponse, IncomingBody, IncomingResponse, OutgoingBody, + OutgoingRequest, }; use spin_executor::bindings::wasi::io; @@ -8,7 +9,7 @@ use spin_executor::bindings::wasi::io::streams::{InputStream, OutputStream, Stre use futures::{future, sink, stream, Sink, Stream}; -pub use spin_executor::run; +pub use spin_executor::{run, CancelOnDropToken}; use std::cell::RefCell; use std::future::Future; @@ -18,11 +19,16 @@ use std::task::Poll; const READ_SIZE: u64 = 16 * 1024; pub(crate) fn outgoing_body(body: OutgoingBody) -> impl Sink, Error = StreamError> { - struct Outgoing(Option<(OutputStream, OutgoingBody)>); + struct Outgoing { + stream_and_body: Option<(OutputStream, OutgoingBody)>, + cancel_token: Option, + } impl Drop for Outgoing { fn drop(&mut self) { - if let Some((stream, body)) = self.0.take() { + drop(self.cancel_token.take()); + + if let Some((stream, body)) = self.stream_and_body.take() { drop(stream); _ = OutgoingBody::finish(body, None); } @@ -30,25 +36,29 @@ pub(crate) fn outgoing_body(body: OutgoingBody) -> impl Sink, Error = St } let stream = body.write().expect("response body should be writable"); - let pair = Rc::new(RefCell::new(Outgoing(Some((stream, body))))); + let outgoing = Rc::new(RefCell::new(Outgoing { + stream_and_body: Some((stream, body)), + cancel_token: None, + })); sink::unfold((), { move |(), chunk: Vec| { future::poll_fn({ let mut offset = 0; let mut flushing = false; - let pair = pair.clone(); + let outgoing = outgoing.clone(); move |context| { - let pair = pair.borrow(); - let (stream, _) = &pair.0.as_ref().unwrap(); + let mut outgoing = outgoing.borrow_mut(); + let (stream, _) = &outgoing.stream_and_body.as_ref().unwrap(); loop { match stream.check_write() { Ok(0) => { - spin_executor::push_waker( - stream.subscribe(), - context.waker().clone(), - ); + outgoing.cancel_token = + Some(CancelOnDropToken::from(spin_executor::push_waker( + stream.subscribe(), + context.waker().clone(), + ))); break Poll::Pending; } Ok(count) => { @@ -93,14 +103,33 @@ pub(crate) fn outgoing_body(body: OutgoingBody) -> impl Sink, Error = St pub(crate) fn outgoing_request_send( request: OutgoingRequest, ) -> impl Future> { + struct State { + response: Option>, + cancel_token: Option, + } + + impl Drop for State { + fn drop(&mut self) { + drop(self.cancel_token.take()); + drop(self.response.take()); + } + } + let response = outgoing_handler::handle(request, None); + let mut state = State { + response: Some(response), + cancel_token: None, + }; future::poll_fn({ - move |context| match &response { + move |context| match &state.response.as_ref().unwrap() { Ok(response) => { if let Some(response) = response.get() { Poll::Ready(response.unwrap()) } else { - spin_executor::push_waker(response.subscribe(), context.waker().clone()); + state.cancel_token = Some(CancelOnDropToken::from(spin_executor::push_waker( + response.subscribe(), + context.waker().clone(), + ))); Poll::Pending } } @@ -113,11 +142,16 @@ pub(crate) fn outgoing_request_send( pub fn incoming_body( body: IncomingBody, ) -> impl Stream, io::streams::Error>> { - struct Incoming(Option<(InputStream, IncomingBody)>); + struct Incoming { + stream_and_body: Option<(InputStream, IncomingBody)>, + cancel_token: Option, + } impl Drop for Incoming { fn drop(&mut self) { - if let Some((stream, body)) = self.0.take() { + drop(self.cancel_token.take()); + + if let Some((stream, body)) = self.stream_and_body.take() { drop(stream); IncomingBody::finish(body); } @@ -126,14 +160,21 @@ pub fn incoming_body( stream::poll_fn({ let stream = body.stream().expect("response body should be readable"); - let pair = Incoming(Some((stream, body))); + let mut incoming = Incoming { + stream_and_body: Some((stream, body)), + cancel_token: None, + }; move |context| { - if let Some((stream, _)) = &pair.0 { + if let Some((stream, _)) = &incoming.stream_and_body { match stream.read(READ_SIZE) { Ok(buffer) => { if buffer.is_empty() { - spin_executor::push_waker(stream.subscribe(), context.waker().clone()); + incoming.cancel_token = + Some(CancelOnDropToken::from(spin_executor::push_waker( + stream.subscribe(), + context.waker().clone(), + ))); Poll::Pending } else { Poll::Ready(Some(Ok(buffer)))