diff --git a/Cargo.lock b/Cargo.lock index 551be2e..39134dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1869,7 +1869,7 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "spin-executor" -version = "3.1.0" +version = "3.1.1" dependencies = [ "futures", "once_cell", @@ -1878,7 +1878,7 @@ dependencies = [ [[package]] name = "spin-macro" -version = "3.1.0" +version = "3.1.1" dependencies = [ "anyhow", "bytes", @@ -1889,7 +1889,7 @@ dependencies = [ [[package]] name = "spin-sdk" -version = "3.1.0" +version = "3.1.1" dependencies = [ "anyhow", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index 04de50d..4658e5c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,7 +81,7 @@ wasmtime-wasi-http = "18.0.1" wit-component = "0.200.0" [workspace.package] -version = "3.1.0" +version = "3.1.1" authors = ["Fermyon Engineering "] edition = "2021" license = "Apache-2.0 WITH LLVM-exception" diff --git a/crates/executor/src/lib.rs b/crates/executor/src/lib.rs index 1fca0db..4486921 100644 --- a/crates/executor/src/lib.rs +++ b/crates/executor/src/lib.rs @@ -25,8 +25,8 @@ type Wrapped = Arc>>; static WAKERS: Mutex> = Mutex::new(Vec::new()); -/// Handle to a Pollable pushed using `push_waker` which may be used to cancel -/// and drop the Pollable. +/// Handle to a Pollable registered using `push_waker_and_get_token` which may +/// be used to cancel and drop the Pollable. pub struct CancelToken(Wrapped); impl CancelToken { @@ -36,8 +36,8 @@ impl CancelToken { } } -/// Handle to a Pollable pushed using `push_waker` which, when dropped, will -/// cancel and drop the Pollable. +/// Handle to a Pollable registered using `push_waker_and_get_token` which, when +/// dropped, will cancel and drop the Pollable. pub struct CancelOnDropToken(Wrapped); impl From for CancelOnDropToken { @@ -52,16 +52,27 @@ impl Drop for CancelOnDropToken { } } -/// Push a Pollable and Waker to WAKERS. -pub fn push_waker(pollable: io::poll::Pollable, waker: Waker) -> CancelToken { +/// Register a `Pollable` and `Waker` to be polled as part of the [`run`] event +/// loop. +pub fn push_waker(pollable: io::poll::Pollable, waker: Waker) { + _ = push_waker_and_get_token(pollable, waker); +} + +/// Register a `Pollable` and `Waker` to be polled as part of the [`run`] event +/// loop and retrieve a [`CancelToken`] to cancel the registration later, if +/// desired. +pub fn push_waker_and_get_token(pollable: io::poll::Pollable, waker: Waker) -> CancelToken { let wrapped = Arc::new(Mutex::new(Some(pollable))); WAKERS.lock().unwrap().push((wrapped.clone(), waker)); CancelToken(wrapped) } -/// Run the specified future to completion blocking until it yields a result. +/// Run the specified future to completion, blocking until it yields a result. /// -/// Based on an executor using `wasi::io/poll/poll-list`, +/// This will alternate between polling the specified future and polling any +/// `Pollable`s registered using [`push_waker`] or [`push_waker_and_get_token`] +/// using `wasi::io/poll/poll-list`. It will panic if the future returns +/// `Poll::Pending` without having registered at least one `Pollable`. pub fn run(future: impl Future) -> T { futures::pin_mut!(future); struct DummyWaker; @@ -85,6 +96,8 @@ pub fn run(future: impl Future) -> T { }) .collect::>(); + assert!(!wakers.is_empty()); + let pollables = wakers .iter() .map(|(_, pollable, _)| pollable) diff --git a/src/http/executor.rs b/src/http/executor.rs index a985adb..027fb84 100644 --- a/src/http/executor.rs +++ b/src/http/executor.rs @@ -54,11 +54,12 @@ pub(crate) fn outgoing_body(body: OutgoingBody) -> impl Sink, Error = St loop { match stream.check_write() { Ok(0) => { - outgoing.cancel_token = - Some(CancelOnDropToken::from(spin_executor::push_waker( + outgoing.cancel_token = Some(CancelOnDropToken::from( + spin_executor::push_waker_and_get_token( stream.subscribe(), context.waker().clone(), - ))); + ), + )); break Poll::Pending; } Ok(count) => { @@ -126,10 +127,12 @@ pub(crate) fn outgoing_request_send( if let Some(response) = response.get() { Poll::Ready(response.unwrap()) } else { - state.cancel_token = Some(CancelOnDropToken::from(spin_executor::push_waker( - response.subscribe(), - context.waker().clone(), - ))); + state.cancel_token = Some(CancelOnDropToken::from( + spin_executor::push_waker_and_get_token( + response.subscribe(), + context.waker().clone(), + ), + )); Poll::Pending } } @@ -170,11 +173,12 @@ pub fn incoming_body( match stream.read(READ_SIZE) { Ok(buffer) => { if buffer.is_empty() { - incoming.cancel_token = - Some(CancelOnDropToken::from(spin_executor::push_waker( + incoming.cancel_token = Some(CancelOnDropToken::from( + spin_executor::push_waker_and_get_token( stream.subscribe(), context.waker().clone(), - ))); + ), + )); Poll::Pending } else { Poll::Ready(Some(Ok(buffer))) diff --git a/src/llm.rs b/src/llm.rs index 162fda4..53aa530 100644 --- a/src/llm.rs +++ b/src/llm.rs @@ -54,7 +54,7 @@ pub enum InferencingModel<'a> { Other(&'a str), } -impl<'a> std::fmt::Display for InferencingModel<'a> { +impl std::fmt::Display for InferencingModel<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let str = match self { InferencingModel::Llama2Chat => "llama2-chat", @@ -100,7 +100,7 @@ pub enum EmbeddingModel<'a> { Other(&'a str), } -impl<'a> std::fmt::Display for EmbeddingModel<'a> { +impl std::fmt::Display for EmbeddingModel<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let str = match self { EmbeddingModel::AllMiniLmL6V2 => "all-minilm-l6-v2",