diff --git a/crates/factor-outbound-http/src/lib.rs b/crates/factor-outbound-http/src/lib.rs index ef0ef8f95..74691122a 100644 --- a/crates/factor-outbound-http/src/lib.rs +++ b/crates/factor-outbound-http/src/lib.rs @@ -22,6 +22,7 @@ use spin_factors::{ anyhow, ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, SelfInstanceBuilder, }; +use tokio::sync::Semaphore; use wasmtime_wasi_http::WasiHttpCtx; pub use wasmtime_wasi_http::{ @@ -53,13 +54,15 @@ impl Factor for OutboundHttpFactor { &self, mut ctx: ConfigureAppContext, ) -> anyhow::Result { - let connection_pooling = ctx - .take_runtime_config() - .unwrap_or_default() - .connection_pooling; + let config = ctx.take_runtime_config().unwrap_or_default(); Ok(AppState { - wasi_http_clients: wasi::HttpClients::new(connection_pooling), - connection_pooling, + wasi_http_clients: wasi::HttpClients::new(config.connection_pooling_enabled), + connection_pooling_enabled: config.connection_pooling_enabled, + concurrent_outbound_connections_semaphore: config + .max_concurrent_connections + // Permit count is the max concurrent connections + 1. + // i.e., 0 concurrent connections means 1 total connection. + .map(|n| Arc::new(Semaphore::new(n + 1))), }) } @@ -80,7 +83,11 @@ impl Factor for OutboundHttpFactor { request_interceptor: None, spin_http_client: None, wasi_http_clients: ctx.app_state().wasi_http_clients.clone(), - connection_pooling: ctx.app_state().connection_pooling, + connection_pooling_enabled: ctx.app_state().connection_pooling_enabled, + concurrent_outbound_connections_semaphore: ctx + .app_state() + .concurrent_outbound_connections_semaphore + .clone(), }) } } @@ -94,7 +101,7 @@ pub struct InstanceState { request_interceptor: Option>, // Connection-pooling client for 'fermyon:spin/http' interface // - // TODO: We could move this to `AppState` to like the + // TODO: We could move this to `AppState` like the // `wasi:http/outgoing-handler` pool for consistency, although it's probably // not a high priority given that `fermyon:spin/http` is deprecated anyway. spin_http_client: Option, @@ -103,7 +110,10 @@ pub struct InstanceState { // This is a clone of `AppState::wasi_http_clients`, meaning it is shared // among all instances of the app. wasi_http_clients: wasi::HttpClients, - connection_pooling: bool, + /// Whether connection pooling is enabled for this instance. + connection_pooling_enabled: bool, + /// A semaphore to limit the number of concurrent outbound connections. + concurrent_outbound_connections_semaphore: Option>, } impl InstanceState { @@ -185,5 +195,8 @@ impl std::fmt::Display for SelfRequestOrigin { pub struct AppState { // Connection pooling clients for `wasi:http/outgoing-handler` interface wasi_http_clients: wasi::HttpClients, - connection_pooling: bool, + /// Whether connection pooling is enabled for this app. + connection_pooling_enabled: bool, + /// A semaphore to limit the number of concurrent outbound connections. + concurrent_outbound_connections_semaphore: Option>, } diff --git a/crates/factor-outbound-http/src/runtime_config.rs b/crates/factor-outbound-http/src/runtime_config.rs index 5c2b5b3a6..9b8ecdad7 100644 --- a/crates/factor-outbound-http/src/runtime_config.rs +++ b/crates/factor-outbound-http/src/runtime_config.rs @@ -5,13 +5,16 @@ pub mod spin; #[derive(Debug)] pub struct RuntimeConfig { /// If true, enable connection pooling and reuse. - pub connection_pooling: bool, + pub connection_pooling_enabled: bool, + /// If set, limits the number of concurrent outbound connections. + pub max_concurrent_connections: Option, } impl Default for RuntimeConfig { fn default() -> Self { Self { - connection_pooling: true, + connection_pooling_enabled: true, + max_concurrent_connections: None, } } } diff --git a/crates/factor-outbound-http/src/runtime_config/spin.rs b/crates/factor-outbound-http/src/runtime_config/spin.rs index 65aa483b7..fc32c2cdc 100644 --- a/crates/factor-outbound-http/src/runtime_config/spin.rs +++ b/crates/factor-outbound-http/src/runtime_config/spin.rs @@ -6,17 +6,17 @@ use spin_factors::runtime_config::toml::GetTomlValue; /// Expects table to be in the format: /// ```toml /// [outbound_http] -/// connection_pooling = true +/// connection_pooling = true # optional, defaults to true +/// max_concurrent_requests = 10 # optional, defaults to unlimited /// ``` pub fn config_from_table( table: &impl GetTomlValue, ) -> anyhow::Result> { if let Some(outbound_http) = table.get("outbound_http") { + let outbound_http_toml = outbound_http.clone().try_into::()?; Ok(Some(super::RuntimeConfig { - connection_pooling: outbound_http - .clone() - .try_into::()? - .connection_pooling, + connection_pooling_enabled: outbound_http_toml.connection_pooling, + max_concurrent_connections: outbound_http_toml.max_concurrent_requests, })) } else { Ok(None) @@ -28,4 +28,6 @@ pub fn config_from_table( struct OutboundHttpToml { #[serde(default)] connection_pooling: bool, + #[serde(default)] + max_concurrent_requests: Option, } diff --git a/crates/factor-outbound-http/src/spin.rs b/crates/factor-outbound-http/src/spin.rs index 2f8fc428f..8370aab63 100644 --- a/crates/factor-outbound-http/src/spin.rs +++ b/crates/factor-outbound-http/src/spin.rs @@ -21,7 +21,6 @@ impl spin_http::Host for crate::InstanceState { if !req.params.is_empty() { tracing::warn!("HTTP params field is deprecated"); } - let req_url = if !uri.starts_with('/') { // Absolute URI let is_allowed = self @@ -92,13 +91,21 @@ impl spin_http::Host for crate::InstanceState { // in a single component execution let client = self.spin_http_client.get_or_insert_with(|| { let mut builder = reqwest::Client::builder(); - if !self.connection_pooling { + if !self.connection_pooling_enabled { builder = builder.pool_max_idle_per_host(0); } builder.build().unwrap() }); + // If we're limiting concurrent outbound requests, acquire a permit + // Note: since we don't have access to the underlying connection, we can only + // limit the number of concurrent requests, not connections. + let permit = match &self.concurrent_outbound_connections_semaphore { + Some(s) => s.acquire().await.ok(), + None => None, + }; let resp = client.execute(req).await.map_err(log_reqwest_error)?; + drop(permit); tracing::trace!("Returning response from outbound request to {req_url}"); span.record("http.response.status_code", resp.status().as_u16()); diff --git a/crates/factor-outbound-http/src/wasi.rs b/crates/factor-outbound-http/src/wasi.rs index 2e75bdb9e..2250cc482 100644 --- a/crates/factor-outbound-http/src/wasi.rs +++ b/crates/factor-outbound-http/src/wasi.rs @@ -28,6 +28,7 @@ use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceSta use tokio::{ io::{AsyncRead, AsyncWrite, ReadBuf}, net::TcpStream, + sync::{OwnedSemaphorePermit, Semaphore}, time::timeout, }; use tokio_rustls::client::TlsStream; @@ -90,6 +91,9 @@ impl p3::WasiHttpCtx for InstanceState { self_request_origin: self.self_request_origin.clone(), blocked_networks: self.blocked_networks.clone(), http_clients: self.wasi_http_clients.clone(), + concurrent_outbound_connections_semaphore: self + .concurrent_outbound_connections_semaphore + .clone(), }; let config = OutgoingRequestConfig { use_tls: request.uri().scheme() == Some(&Scheme::HTTPS), @@ -282,6 +286,10 @@ impl WasiHttpView for WasiHttpImplInner<'_> { self_request_origin: self.state.self_request_origin.clone(), blocked_networks: self.state.blocked_networks.clone(), http_clients: self.state.wasi_http_clients.clone(), + concurrent_outbound_connections_semaphore: self + .state + .concurrent_outbound_connections_semaphore + .clone(), }; Ok(HostFutureIncomingResponse::Pending( wasmtime_wasi::runtime::spawn( @@ -307,6 +315,7 @@ struct RequestSender { self_request_origin: Option, request_interceptor: Option>, http_clients: HttpClients, + concurrent_outbound_connections_semaphore: Option>, } impl RequestSender { @@ -454,6 +463,8 @@ impl RequestSender { connect_timeout, tls_client_config, override_connect_addr, + concurrent_outbound_connections_semaphore: self + .concurrent_outbound_connections_semaphore, }, async move { if use_tls { @@ -520,24 +531,38 @@ impl HttpClients { } } -// We must use task-local variables for these config options when using -// `hyper_util::client::legacy::Client::request` because there's no way to plumb -// them through as parameters. Moreover, if there's already a pooled connection -// ready, we'll reuse that and ignore these options anyway. tokio::task_local! { + /// The options used when establishing a new connection. + /// + /// We must use task-local variables for these config options when using + /// `hyper_util::client::legacy::Client::request` because there's no way to plumb + /// them through as parameters. Moreover, if there's already a pooled connection + /// ready, we'll reuse that and ignore these options anyway. After each connection + /// is established, the options are dropped. static CONNECT_OPTIONS: ConnectOptions; } #[derive(Clone)] struct ConnectOptions { + /// The blocked networks configuration. blocked_networks: BlockedNetworks, + /// Timeout for establishing a TCP connection. connect_timeout: Duration, + /// TLS client configuration to use, if any. tls_client_config: Option, + /// If set, override the address to connect to instead of using the given `uri`'s authority. override_connect_addr: Option, + /// A semaphore to limit the number of concurrent outbound connections. + concurrent_outbound_connections_semaphore: Option>, } impl ConnectOptions { - async fn connect_tcp(&self, uri: &Uri, default_port: u16) -> Result { + /// Establish a TCP connection to the given URI and default port. + async fn connect_tcp( + &self, + uri: &Uri, + default_port: u16, + ) -> Result { let mut socket_addrs = match self.override_connect_addr { Some(override_connect_addr) => vec![override_connect_addr], None => { @@ -572,7 +597,13 @@ impl ConnectOptions { return Err(ErrorCode::DestinationIpProhibited); } - timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs)) + // If we're limiting concurrent outbound requests, acquire a permit + let permit = match &self.concurrent_outbound_connections_semaphore { + Some(s) => s.clone().acquire_owned().await.ok(), + None => None, + }; + + let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs)) .await .map_err(|_| ErrorCode::ConnectionTimeout)? .map_err(|err| match err.kind() { @@ -580,14 +611,19 @@ impl ConnectOptions { dns_error("address not available".into(), 0) } _ => ErrorCode::ConnectionRefused, - }) + })?; + Ok(PermittedTcpStream { + inner: stream, + _permit: permit, + }) } + /// Establish a TLS connection to the given URI and default port. async fn connect_tls( &self, uri: &Uri, default_port: u16, - ) -> Result, ErrorCode> { + ) -> Result, ErrorCode> { let tcp_stream = self.connect_tcp(uri, default_port).await?; let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone(); @@ -597,7 +633,7 @@ impl ConnectOptions { let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap()) .map_err(|e| { tracing::warn!("dns lookup error: {e:?}"); - dns_error("invalid dns name".to_string(), 0) + dns_error("invalid dns name".into(), 0) })? .to_owned(); connector.connect(domain, tcp_stream).await.map_err(|e| { @@ -607,20 +643,22 @@ impl ConnectOptions { } } +/// A connector the uses `ConnectOptions` #[derive(Clone)] struct HttpConnector; impl HttpConnector { - async fn connect(uri: Uri) -> Result, ErrorCode> { + async fn connect(uri: Uri) -> Result, ErrorCode> { let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?; Ok(TokioIo::new(stream)) } } impl Service for HttpConnector { - type Response = TokioIo; + type Response = TokioIo; type Error = ErrorCode; - type Future = Pin, ErrorCode>> + Send>>; + type Future = + Pin, ErrorCode>> + Send>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) @@ -631,6 +669,7 @@ impl Service for HttpConnector { } } +/// A connector that establishes TLS connections using `rustls` and `ConnectOptions`. #[derive(Clone)] struct HttpsConnector; @@ -655,7 +694,7 @@ impl Service for HttpsConnector { } } -struct RustlsStream(TlsStream); +struct RustlsStream(TlsStream); impl Connection for RustlsStream { fn connected(&self) -> Connected { @@ -710,6 +749,54 @@ impl AsyncWrite for RustlsStream { } } +/// A TCP stream that holds an optional permit indicating that it is allowed to exist. +struct PermittedTcpStream { + /// The wrapped TCP stream. + inner: TcpStream, + /// A permit indicating that this stream is allowed to exist. + /// + /// When this stream is dropped, the permit is also dropped, allowing another + /// connection to be established. + _permit: Option, +} + +impl Connection for PermittedTcpStream { + fn connected(&self) -> Connected { + self.inner.connected() + } +} + +impl AsyncRead for PermittedTcpStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for PermittedTcpStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_shutdown(cx) + } +} + /// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request. fn hyper_request_error(err: hyper::Error) -> ErrorCode { // If there's a source, we might be able to extract a wasi-http error from it.