33 changes: 23 additions & 10 deletions crates/factor-outbound-http/src/lib.rs
@@ -22,6 +22,7 @@ use spin_factors::{
anyhow, ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors,
SelfInstanceBuilder,
};
use tokio::sync::Semaphore;
use wasmtime_wasi_http::WasiHttpCtx;

pub use wasmtime_wasi_http::{
@@ -53,13 +54,15 @@ impl Factor for OutboundHttpFactor {
&self,
mut ctx: ConfigureAppContext<T, Self>,
) -> anyhow::Result<Self::AppState> {
let connection_pooling = ctx
.take_runtime_config()
.unwrap_or_default()
.connection_pooling;
let config = ctx.take_runtime_config().unwrap_or_default();
Ok(AppState {
wasi_http_clients: wasi::HttpClients::new(connection_pooling),
connection_pooling,
wasi_http_clients: wasi::HttpClients::new(config.connection_pooling_enabled),
connection_pooling_enabled: config.connection_pooling_enabled,
concurrent_outbound_connections_semaphore: config
.max_concurrent_connections
// Permit count is the max concurrent connections + 1.
// i.e., 0 concurrent connections means 1 total connection.
.map(|n| Arc::new(Semaphore::new(n + 1))),
Comment on lines +63 to +65

Collaborator:
Not a huge deal since we're talking about rather large limits, but I don't understand this interpretation. I would expect "0 concurrent connections" to mean "you can't ever connect" (which should perhaps be invalid).

Collaborator (author):
I guess I have a different interpretation. A concurrent request is by definition one that is happening during another request. "0 concurrent requests" would therefore mean "no requests happening at the same time", not "no requests at all". I think we should just pick whichever interpretation is most convenient and use that.

Collaborator:
How many requests are involved in "2 concurrent requests"?

Collaborator:
"A concurrent request is by definition one that is happening during another request"

That is strictly true, but I think that doing the +1 makes everything else more confusing: the first request becomes concurrent once there are two total requests.

I'd vote for removing the +1.

})
}
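
The review thread above comes down to how `tokio::sync::Semaphore` counts permits. A minimal, self-contained sketch of the behaviour under the PR's `+ 1` reading (plain tokio; the config value and names are hypothetical stand-ins, not Spin's types):

```rust
use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    // Hypothetical runtime-config value, as in `max_concurrent_connections = 0`.
    let max_concurrent_connections = 0usize;

    // The PR's reading: permits = max + 1, so "0 concurrent connections"
    // still allows one request in flight at a time.
    let sem = Arc::new(Semaphore::new(max_concurrent_connections + 1));

    // The first request takes the only permit and proceeds.
    let first = sem.clone().acquire_owned().await.unwrap();

    // A second request would have to wait until that permit is released.
    assert!(sem.clone().try_acquire_owned().is_err());

    // Dropping the permit (e.g. when the connection closes) unblocks the next one.
    drop(first);
    assert!(sem.try_acquire_owned().is_ok());
}
```

Dropping the `+ 1`, as suggested above, would instead leave `Semaphore::new(0)` with no permits to hand out, so no request could ever proceed.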

@@ -80,7 +83,11 @@ impl Factor for OutboundHttpFactor {
request_interceptor: None,
spin_http_client: None,
wasi_http_clients: ctx.app_state().wasi_http_clients.clone(),
connection_pooling: ctx.app_state().connection_pooling,
connection_pooling_enabled: ctx.app_state().connection_pooling_enabled,
concurrent_outbound_connections_semaphore: ctx
.app_state()
.concurrent_outbound_connections_semaphore
.clone(),
})
}
}
@@ -94,7 +101,7 @@ pub struct InstanceState {
request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
// Connection-pooling client for 'fermyon:spin/http' interface
//
// TODO: We could move this to `AppState` to like the
// TODO: We could move this to `AppState` like the
// `wasi:http/outgoing-handler` pool for consistency, although it's probably
// not a high priority given that `fermyon:spin/http` is deprecated anyway.
spin_http_client: Option<reqwest::Client>,
@@ -103,7 +110,10 @@ pub struct InstanceState {
// This is a clone of `AppState::wasi_http_clients`, meaning it is shared
// among all instances of the app.
wasi_http_clients: wasi::HttpClients,
connection_pooling: bool,
/// Whether connection pooling is enabled for this instance.
connection_pooling_enabled: bool,
/// A semaphore to limit the number of concurrent outbound connections.
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
}

impl InstanceState {
@@ -185,5 +195,8 @@ impl std::fmt::Display for SelfRequestOrigin {
pub struct AppState {
// Connection pooling clients for `wasi:http/outgoing-handler` interface
wasi_http_clients: wasi::HttpClients,
connection_pooling: bool,
/// Whether connection pooling is enabled for this app.
connection_pooling_enabled: bool,
/// A semaphore to limit the number of concurrent outbound connections.
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
}
7 changes: 5 additions & 2 deletions crates/factor-outbound-http/src/runtime_config.rs
@@ -5,13 +5,16 @@ pub mod spin;
#[derive(Debug)]
pub struct RuntimeConfig {
/// If true, enable connection pooling and reuse.
pub connection_pooling: bool,
pub connection_pooling_enabled: bool,
/// If set, limits the number of concurrent outbound connections.
pub max_concurrent_connections: Option<usize>,
}

impl Default for RuntimeConfig {
fn default() -> Self {
Self {
connection_pooling: true,
connection_pooling_enabled: true,
max_concurrent_connections: None,
}
}
}
12 changes: 7 additions & 5 deletions crates/factor-outbound-http/src/runtime_config/spin.rs
@@ -6,17 +6,17 @@ use spin_factors::runtime_config::toml::GetTomlValue;
/// Expects table to be in the format:
/// ```toml
/// [outbound_http]
/// connection_pooling = true
/// connection_pooling = true # optional, defaults to true
/// max_concurrent_requests = 10 # optional, defaults to unlimited
/// ```
pub fn config_from_table(
table: &impl GetTomlValue,
) -> anyhow::Result<Option<super::RuntimeConfig>> {
if let Some(outbound_http) = table.get("outbound_http") {
let outbound_http_toml = outbound_http.clone().try_into::<OutboundHttpToml>()?;
Ok(Some(super::RuntimeConfig {
connection_pooling: outbound_http
.clone()
.try_into::<OutboundHttpToml>()?
.connection_pooling,
connection_pooling_enabled: outbound_http_toml.connection_pooling,
max_concurrent_connections: outbound_http_toml.max_concurrent_requests,
}))
} else {
Ok(None)
@@ -28,4 +28,6 @@ pub fn config_from_table(
struct OutboundHttpToml {
#[serde(default)]
connection_pooling: bool,
#[serde(default)]
max_concurrent_requests: Option<usize>,
}
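
For reference, the serde defaults on `OutboundHttpToml` can be exercised in isolation. A small sketch using the `toml` and `serde` crates directly (the struct mirrors the one above but lives outside Spin's `GetTomlValue` plumbing; field names are taken from this diff, everything else is local to the sketch):

```rust
use serde::Deserialize;

// Mirror of the `[outbound_http]` table shape parsed above.
#[derive(Debug, Deserialize)]
struct OutboundHttpToml {
    #[serde(default)]
    connection_pooling: bool,
    #[serde(default)]
    max_concurrent_requests: Option<usize>,
}

fn main() -> Result<(), toml::de::Error> {
    let parsed: OutboundHttpToml = toml::from_str(
        r#"
        connection_pooling = true
        max_concurrent_requests = 10
        "#,
    )?;
    assert!(parsed.connection_pooling);
    assert_eq!(parsed.max_concurrent_requests, Some(10));

    // Omitted keys fall back to the serde defaults: `false` and `None`.
    let defaults: OutboundHttpToml = toml::from_str("")?;
    assert!(!defaults.connection_pooling);
    assert_eq!(defaults.max_concurrent_requests, None);
    Ok(())
}
```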
11 changes: 9 additions & 2 deletions crates/factor-outbound-http/src/spin.rs
@@ -21,7 +21,6 @@ impl spin_http::Host for crate::InstanceState {
if !req.params.is_empty() {
tracing::warn!("HTTP params field is deprecated");
}

let req_url = if !uri.starts_with('/') {
// Absolute URI
let is_allowed = self
@@ -92,13 +91,21 @@ impl spin_http::Host for crate::InstanceState {
// in a single component execution
let client = self.spin_http_client.get_or_insert_with(|| {
let mut builder = reqwest::Client::builder();
if !self.connection_pooling {
if !self.connection_pooling_enabled {
builder = builder.pool_max_idle_per_host(0);
}
builder.build().unwrap()
});

// If we're limiting concurrent outbound requests, acquire a permit
// Note: since we don't have access to the underlying connection, we can only
// limit the number of concurrent requests, not connections.
let permit = match &self.concurrent_outbound_connections_semaphore {
Some(s) => s.acquire().await.ok(),
None => None,
};
let resp = client.execute(req).await.map_err(log_reqwest_error)?;
drop(permit);

tracing::trace!("Returning response from outbound request to {req_url}");
span.record("http.response.status_code", resp.status().as_u16());
113 changes: 100 additions & 13 deletions crates/factor-outbound-http/src/wasi.rs
@@ -28,6 +28,7 @@ use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceSt
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::TcpStream,
sync::{OwnedSemaphorePermit, Semaphore},
time::timeout,
};
use tokio_rustls::client::TlsStream;
@@ -90,6 +91,9 @@ impl p3::WasiHttpCtx for InstanceState {
self_request_origin: self.self_request_origin.clone(),
blocked_networks: self.blocked_networks.clone(),
http_clients: self.wasi_http_clients.clone(),
concurrent_outbound_connections_semaphore: self
.concurrent_outbound_connections_semaphore
.clone(),
};
let config = OutgoingRequestConfig {
use_tls: request.uri().scheme() == Some(&Scheme::HTTPS),
@@ -282,6 +286,10 @@ impl WasiHttpView for WasiHttpImplInner<'_> {
self_request_origin: self.state.self_request_origin.clone(),
blocked_networks: self.state.blocked_networks.clone(),
http_clients: self.state.wasi_http_clients.clone(),
concurrent_outbound_connections_semaphore: self
.state
.concurrent_outbound_connections_semaphore
.clone(),
};
Ok(HostFutureIncomingResponse::Pending(
wasmtime_wasi::runtime::spawn(
@@ -307,6 +315,7 @@ struct RequestSender {
self_request_origin: Option<SelfRequestOrigin>,
request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
http_clients: HttpClients,
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
}

impl RequestSender {
@@ -454,6 +463,8 @@ impl RequestSender {
connect_timeout,
tls_client_config,
override_connect_addr,
concurrent_outbound_connections_semaphore: self
.concurrent_outbound_connections_semaphore,
},
async move {
if use_tls {
@@ -520,24 +531,38 @@ impl HttpClients {
}
}

// We must use task-local variables for these config options when using
// `hyper_util::client::legacy::Client::request` because there's no way to plumb
// them through as parameters. Moreover, if there's already a pooled connection
// ready, we'll reuse that and ignore these options anyway.
tokio::task_local! {
/// The options used when establishing a new connection.
///
/// We must use task-local variables for these config options when using
/// `hyper_util::client::legacy::Client::request` because there's no way to plumb
/// them through as parameters. Moreover, if there's already a pooled connection
/// ready, we'll reuse that and ignore these options anyway. After each connection
/// is established, the options are dropped.
static CONNECT_OPTIONS: ConnectOptions;
}

#[derive(Clone)]
struct ConnectOptions {
/// The blocked networks configuration.
blocked_networks: BlockedNetworks,
/// Timeout for establishing a TCP connection.
connect_timeout: Duration,
/// TLS client configuration to use, if any.
tls_client_config: Option<TlsClientConfig>,
/// If set, override the address to connect to instead of using the given `uri`'s authority.
override_connect_addr: Option<SocketAddr>,
/// A semaphore to limit the number of concurrent outbound connections.
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
}

impl ConnectOptions {
async fn connect_tcp(&self, uri: &Uri, default_port: u16) -> Result<TcpStream, ErrorCode> {
/// Establish a TCP connection to the given URI and default port.
async fn connect_tcp(
&self,
uri: &Uri,
default_port: u16,
) -> Result<PermittedTcpStream, ErrorCode> {
let mut socket_addrs = match self.override_connect_addr {
Some(override_connect_addr) => vec![override_connect_addr],
None => {
@@ -572,22 +597,33 @@ impl ConnectOptions {
return Err(ErrorCode::DestinationIpProhibited);
}

timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
// If we're limiting concurrent outbound requests, acquire a permit
let permit = match &self.concurrent_outbound_connections_semaphore {
Some(s) => s.clone().acquire_owned().await.ok(),
None => None,
};

let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
.await
.map_err(|_| ErrorCode::ConnectionTimeout)?
.map_err(|err| match err.kind() {
std::io::ErrorKind::AddrNotAvailable => {
dns_error("address not available".into(), 0)
}
_ => ErrorCode::ConnectionRefused,
})
})?;
Ok(PermittedTcpStream {
inner: stream,
_permit: permit,
})
}

/// Establish a TLS connection to the given URI and default port.
async fn connect_tls(
&self,
uri: &Uri,
default_port: u16,
) -> Result<TlsStream<TcpStream>, ErrorCode> {
) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
let tcp_stream = self.connect_tcp(uri, default_port).await?;

let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
@@ -597,7 +633,7 @@ impl ConnectOptions {
let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
.map_err(|e| {
tracing::warn!("dns lookup error: {e:?}");
dns_error("invalid dns name".to_string(), 0)
dns_error("invalid dns name".into(), 0)
})?
.to_owned();
connector.connect(domain, tcp_stream).await.map_err(|e| {
@@ -607,20 +643,22 @@ impl ConnectOptions {
}
}

/// A connector that uses `ConnectOptions`
#[derive(Clone)]
struct HttpConnector;

impl HttpConnector {
async fn connect(uri: Uri) -> Result<TokioIo<TcpStream>, ErrorCode> {
async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
Ok(TokioIo::new(stream))
}
}

impl Service<Uri> for HttpConnector {
type Response = TokioIo<TcpStream>;
type Response = TokioIo<PermittedTcpStream>;
type Error = ErrorCode;
type Future = Pin<Box<dyn Future<Output = Result<TokioIo<TcpStream>, ErrorCode>> + Send>>;
type Future =
Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;

fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
@@ -631,6 +669,7 @@ impl Service<Uri> for HttpConnector {
}
}

/// A connector that establishes TLS connections using `rustls` and `ConnectOptions`.
#[derive(Clone)]
struct HttpsConnector;

@@ -655,7 +694,7 @@ impl Service<Uri> for HttpsConnector {
}
}

struct RustlsStream(TlsStream<TcpStream>);
struct RustlsStream(TlsStream<PermittedTcpStream>);

impl Connection for RustlsStream {
fn connected(&self) -> Connected {
@@ -710,6 +749,54 @@ impl AsyncWrite for RustlsStream {
}
}

/// A TCP stream that holds an optional permit indicating that it is allowed to exist.
struct PermittedTcpStream {
/// The wrapped TCP stream.
inner: TcpStream,
/// A permit indicating that this stream is allowed to exist.
///
/// When this stream is dropped, the permit is also dropped, allowing another
/// connection to be established.
_permit: Option<OwnedSemaphorePermit>,
}

impl Connection for PermittedTcpStream {
fn connected(&self) -> Connected {
self.inner.connected()
}
}

impl AsyncRead for PermittedTcpStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
}
}

impl AsyncWrite for PermittedTcpStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.get_mut().inner).poll_flush(cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
}
}

/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
fn hyper_request_error(err: hyper::Error) -> ErrorCode {
// If there's a source, we might be able to extract a wasi-http error from it.
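
One last note on the `CONNECT_OPTIONS` task-local introduced in `wasi.rs`: since `hyper_util::client::legacy::Client::request` offers no per-call way to pass connector configuration, the options are scoped onto the task that drives the request, and the connector reads them back. A minimal sketch of that pattern (the `ConnectOptions` here is a hypothetical stand-in, not Spin's type):

```rust
use std::time::Duration;

// Hypothetical stand-in for the real `ConnectOptions`.
#[derive(Clone, Debug)]
struct ConnectOptions {
    connect_timeout: Duration,
}

tokio::task_local! {
    // Each spawned request task gets its own copy of the options; code running
    // inside that task reads them with `CONNECT_OPTIONS.get()`.
    static CONNECT_OPTIONS: ConnectOptions;
}

async fn connect_with_task_local_options() {
    // Inside the scoped future the connector can read the options even though
    // nothing passed them as an argument.
    let opts = CONNECT_OPTIONS.get();
    println!("connecting with timeout {:?}", opts.connect_timeout);
}

#[tokio::main]
async fn main() {
    let options = ConnectOptions {
        connect_timeout: Duration::from_secs(5),
    };

    // `scope` installs the value for the duration of the wrapped future.
    CONNECT_OPTIONS
        .scope(options, connect_with_task_local_options())
        .await;
}
```

As the doc comment in the diff notes, these options only matter when a new connection is actually established; a request served from an already-pooled connection never consults them.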