-
Notifications
You must be signed in to change notification settings - Fork 290
Set a limit on max number of concurrent outbound http requests #3285
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceSta | |
use tokio::{ | ||
io::{AsyncRead, AsyncWrite, ReadBuf}, | ||
net::TcpStream, | ||
sync::{OwnedSemaphorePermit, Semaphore}, | ||
time::timeout, | ||
}; | ||
use tokio_rustls::client::TlsStream; | ||
|
@@ -90,6 +91,9 @@ impl p3::WasiHttpCtx for InstanceState { | |
self_request_origin: self.self_request_origin.clone(), | ||
blocked_networks: self.blocked_networks.clone(), | ||
http_clients: self.wasi_http_clients.clone(), | ||
concurrent_outbound_connections_semaphore: self | ||
.concurrent_outbound_connections_semaphore | ||
.clone(), | ||
}; | ||
let config = OutgoingRequestConfig { | ||
use_tls: request.uri().scheme() == Some(&Scheme::HTTPS), | ||
|
@@ -282,6 +286,10 @@ impl WasiHttpView for WasiHttpImplInner<'_> { | |
self_request_origin: self.state.self_request_origin.clone(), | ||
blocked_networks: self.state.blocked_networks.clone(), | ||
http_clients: self.state.wasi_http_clients.clone(), | ||
concurrent_outbound_connections_semaphore: self | ||
.state | ||
.concurrent_outbound_connections_semaphore | ||
.clone(), | ||
}; | ||
Ok(HostFutureIncomingResponse::Pending( | ||
wasmtime_wasi::runtime::spawn( | ||
|
@@ -307,6 +315,7 @@ struct RequestSender { | |
self_request_origin: Option<SelfRequestOrigin>, | ||
request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>, | ||
http_clients: HttpClients, | ||
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>, | ||
} | ||
|
||
impl RequestSender { | ||
|
@@ -448,12 +457,18 @@ impl RequestSender { | |
None | ||
}; | ||
|
||
// If we're limiting concurrent outbound requests, acquire a permit | ||
let permit = match self.concurrent_outbound_connections_semaphore { | ||
Some(s) => s.acquire_owned().await.ok().map(Arc::new), | ||
None => None, | ||
}; | ||
|
||
let resp = CONNECT_OPTIONS.scope( | ||
ConnectOptions { | ||
blocked_networks: self.blocked_networks, | ||
connect_timeout, | ||
tls_client_config, | ||
override_connect_addr, | ||
permit, | ||
}, | ||
async move { | ||
if use_tls { | ||
|
@@ -520,24 +535,40 @@ impl HttpClients { | |
} | ||
} | ||
|
||
// We must use task-local variables for these config options when using | ||
// `hyper_util::client::legacy::Client::request` because there's no way to plumb | ||
// them through as parameters. Moreover, if there's already a pooled connection | ||
// ready, we'll reuse that and ignore these options anyway. | ||
tokio::task_local! { | ||
/// The options used when establishing a new connection. | ||
/// | ||
/// We must use task-local variables for these config options when using | ||
/// `hyper_util::client::legacy::Client::request` because there's no way to plumb | ||
/// them through as parameters. Moreover, if there's already a pooled connection | ||
/// ready, we'll reuse that and ignore these options anyway. After each connection | ||
/// is established, the options are dropped. | ||
static CONNECT_OPTIONS: ConnectOptions; | ||
} | ||
|
||
#[derive(Clone)] | ||
struct ConnectOptions { | ||
/// The blocked networks configuration. | ||
blocked_networks: BlockedNetworks, | ||
/// Timeout for establishing a TCP connection. | ||
connect_timeout: Duration, | ||
/// TLS client configuration to use, if any. | ||
tls_client_config: Option<TlsClientConfig>, | ||
/// If set, override the address to connect to instead of using the given `uri`'s authority. | ||
override_connect_addr: Option<SocketAddr>, | ||
/// A permit for this connection | ||
/// | ||
/// If there is a permit, it should be dropped when the connection is closed. | ||
permit: Option<Arc<OwnedSemaphorePermit>>, | ||
} | ||
|
||
impl ConnectOptions { | ||
async fn connect_tcp(&self, uri: &Uri, default_port: u16) -> Result<TcpStream, ErrorCode> { | ||
/// Establish a TCP connection to the given URI and default port. | ||
async fn connect_tcp( | ||
&self, | ||
uri: &Uri, | ||
default_port: u16, | ||
) -> Result<PermittedTcpStream, ErrorCode> { | ||
let mut socket_addrs = match self.override_connect_addr { | ||
Some(override_connect_addr) => vec![override_connect_addr], | ||
None => { | ||
|
@@ -572,22 +603,27 @@ impl ConnectOptions { | |
return Err(ErrorCode::DestinationIpProhibited); | ||
} | ||
|
||
timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs)) | ||
let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs)) | ||
.await | ||
.map_err(|_| ErrorCode::ConnectionTimeout)? | ||
.map_err(|err| match err.kind() { | ||
std::io::ErrorKind::AddrNotAvailable => { | ||
dns_error("address not available".into(), 0) | ||
} | ||
_ => ErrorCode::ConnectionRefused, | ||
}) | ||
})?; | ||
Ok(PermittedTcpStream { | ||
inner: stream, | ||
_permit: self.permit.clone(), | ||
}) | ||
} | ||
|
||
/// Establish a TLS connection to the given URI and default port. | ||
async fn connect_tls( | ||
&self, | ||
uri: &Uri, | ||
default_port: u16, | ||
) -> Result<TlsStream<TcpStream>, ErrorCode> { | ||
) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> { | ||
let tcp_stream = self.connect_tcp(uri, default_port).await?; | ||
|
||
let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone(); | ||
|
@@ -597,7 +633,7 @@ impl ConnectOptions { | |
let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap()) | ||
.map_err(|e| { | ||
tracing::warn!("dns lookup error: {e:?}"); | ||
dns_error("invalid dns name".to_string(), 0) | ||
dns_error("invalid dns name".into(), 0) | ||
})? | ||
.to_owned(); | ||
connector.connect(domain, tcp_stream).await.map_err(|e| { | ||
|
@@ -607,20 +643,22 @@ impl ConnectOptions { | |
} | ||
} | ||
|
||
/// A connector the uses `ConnectOptions` | ||
#[derive(Clone)] | ||
struct HttpConnector; | ||
|
||
impl HttpConnector { | ||
async fn connect(uri: Uri) -> Result<TokioIo<TcpStream>, ErrorCode> { | ||
async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> { | ||
let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?; | ||
Ok(TokioIo::new(stream)) | ||
} | ||
} | ||
|
||
impl Service<Uri> for HttpConnector { | ||
type Response = TokioIo<TcpStream>; | ||
type Response = TokioIo<PermittedTcpStream>; | ||
type Error = ErrorCode; | ||
type Future = Pin<Box<dyn Future<Output = Result<TokioIo<TcpStream>, ErrorCode>> + Send>>; | ||
type Future = | ||
Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>; | ||
|
||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { | ||
Poll::Ready(Ok(())) | ||
|
@@ -631,6 +669,7 @@ impl Service<Uri> for HttpConnector { | |
} | ||
} | ||
|
||
/// A connector that establishes TLS connections using `rustls` and `ConnectOptions`. | ||
#[derive(Clone)] | ||
struct HttpsConnector; | ||
|
||
|
@@ -655,7 +694,7 @@ impl Service<Uri> for HttpsConnector { | |
} | ||
} | ||
|
||
struct RustlsStream(TlsStream<TcpStream>); | ||
struct RustlsStream(TlsStream<PermittedTcpStream>); | ||
|
||
impl Connection for RustlsStream { | ||
fn connected(&self) -> Connected { | ||
|
@@ -710,6 +749,54 @@ impl AsyncWrite for RustlsStream { | |
} | ||
} | ||
|
||
/// A TCP stream that holds an optional permit indicating that it is allowed to exist. | ||
struct PermittedTcpStream { | ||
/// The wrapped TCP stream. | ||
inner: TcpStream, | ||
/// A permit indicating that this stream is allowed to exist. | ||
/// | ||
/// When this stream is dropped, the permit is also dropped, allowing another | ||
/// connection to be established. | ||
_permit: Option<Arc<OwnedSemaphorePermit>>, | ||
} | ||
|
||
impl Connection for PermittedTcpStream { | ||
fn connected(&self) -> Connected { | ||
self.inner.connected() | ||
} | ||
} | ||
|
||
impl AsyncRead for PermittedTcpStream { | ||
fn poll_read( | ||
self: Pin<&mut Self>, | ||
cx: &mut Context<'_>, | ||
buf: &mut ReadBuf<'_>, | ||
) -> Poll<std::io::Result<()>> { | ||
Pin::new(&mut self.get_mut().inner).poll_read(cx, buf) | ||
} | ||
} | ||
|
||
impl AsyncWrite for PermittedTcpStream { | ||
fn poll_write( | ||
self: Pin<&mut Self>, | ||
cx: &mut Context<'_>, | ||
buf: &[u8], | ||
) -> Poll<Result<usize, std::io::Error>> { | ||
Pin::new(&mut self.get_mut().inner).poll_write(cx, buf) | ||
} | ||
|
||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { | ||
Pin::new(&mut self.get_mut().inner).poll_flush(cx) | ||
} | ||
|
||
fn poll_shutdown( | ||
self: Pin<&mut Self>, | ||
cx: &mut Context<'_>, | ||
) -> Poll<Result<(), std::io::Error>> { | ||
Pin::new(&mut self.get_mut().inner).poll_shutdown(cx) | ||
} | ||
} | ||
|
||
/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request. | ||
fn hyper_request_error(err: hyper::Error) -> ErrorCode { | ||
// If there's a source, we might be able to extract a wasi-http error from it. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a huge deal since we're talking about rather large limits but I don't understand this interpretation. I would expect "0 concurrent connections" to mean "you can't ever connect" (which should perhaps be invalid).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess I have a different interpretation. A concurrent request is by definition one that is happening during another request. 0 concurrent requests would therefore mean "no requests happening at the same time" not "no requests at all". I think we should just pick which ever interpretation is most convenient and use that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How many requests are involved in "2 concurrent requests"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is strictly true, but I think that doing the +1 makes everything else more confusing. The first request becomes concurrent once there is two total requests.
I'd vote for removing the
+ 1