Skip to content

Commit f20022c

Browse files
committed
Limit tcp connections instead of http requests
Signed-off-by: Ryan Levick <[email protected]>
1 parent 622dbe9 commit f20022c

File tree

5 files changed

+110
-37
lines changed

5 files changed

+110
-37
lines changed

crates/factor-outbound-http/src/lib.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ impl Factor for OutboundHttpFactor {
5656
Ok(AppState {
5757
wasi_http_clients: wasi::HttpClients::new(config.connection_pooling_enabled),
5858
connection_pooling_enabled: config.connection_pooling_enabled,
59-
concurrent_outbound_requests_semaphore: config
60-
.max_concurrent_requests
61-
// Permit count is the max concurrent requests + 1.
62-
// i.e., 0 concurrent requests means 1 total request.
59+
concurrent_outbound_connections_semaphore: config
60+
.max_concurrent_connections
61+
// Permit count is the max concurrent connections + 1.
62+
// i.e., 0 concurrent connections means 1 total connection.
6363
.map(|n| Arc::new(Semaphore::new(n + 1))),
6464
})
6565
}
@@ -82,9 +82,9 @@ impl Factor for OutboundHttpFactor {
8282
spin_http_client: None,
8383
wasi_http_clients: ctx.app_state().wasi_http_clients.clone(),
8484
connection_pooling_enabled: ctx.app_state().connection_pooling_enabled,
85-
concurrent_outbound_requests_semaphore: ctx
85+
concurrent_outbound_connections_semaphore: ctx
8686
.app_state()
87-
.concurrent_outbound_requests_semaphore
87+
.concurrent_outbound_connections_semaphore
8888
.clone(),
8989
})
9090
}
@@ -110,8 +110,8 @@ pub struct InstanceState {
110110
wasi_http_clients: wasi::HttpClients,
111111
/// Whether connection pooling is enabled for this instance.
112112
connection_pooling_enabled: bool,
113-
/// A semaphore to limit the number of concurrent outbound requests.
114-
concurrent_outbound_requests_semaphore: Option<Arc<Semaphore>>,
113+
/// A semaphore to limit the number of concurrent outbound connections.
114+
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
115115
}
116116

117117
impl InstanceState {
@@ -195,6 +195,6 @@ pub struct AppState {
195195
wasi_http_clients: wasi::HttpClients,
196196
/// Whether connection pooling is enabled for this app.
197197
connection_pooling_enabled: bool,
198-
/// A semaphore to limit the number of concurrent outbound requests.
199-
concurrent_outbound_requests_semaphore: Option<Arc<Semaphore>>,
198+
/// A semaphore to limit the number of concurrent outbound connections.
199+
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
200200
}

crates/factor-outbound-http/src/runtime_config.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@ pub mod spin;
66
pub struct RuntimeConfig {
77
/// If true, enable connection pooling and reuse.
88
pub connection_pooling_enabled: bool,
9-
/// If set, limits the number of concurrent outbound requests.
10-
pub max_concurrent_requests: Option<usize>,
9+
/// If set, limits the number of concurrent outbound connections.
10+
pub max_concurrent_connections: Option<usize>,
1111
}
1212

1313
impl Default for RuntimeConfig {
1414
fn default() -> Self {
1515
Self {
1616
connection_pooling_enabled: true,
17-
max_concurrent_requests: None,
17+
max_concurrent_connections: None,
1818
}
1919
}
2020
}

crates/factor-outbound-http/src/runtime_config/spin.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pub fn config_from_table(
1616
let outbound_http_toml = outbound_http.clone().try_into::<OutboundHttpToml>()?;
1717
Ok(Some(super::RuntimeConfig {
1818
connection_pooling_enabled: outbound_http_toml.connection_pooling,
19-
max_concurrent_requests: outbound_http_toml.max_concurrent_requests,
19+
max_concurrent_connections: outbound_http_toml.max_concurrent_requests,
2020
}))
2121
} else {
2222
Ok(None)

crates/factor-outbound-http/src/spin.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ impl spin_http::Host for crate::InstanceState {
9898
});
9999

100100
// If we're limiting concurrent outbound requests, acquire a permit
101-
let permit = match &self.concurrent_outbound_requests_semaphore {
101+
let permit = match &self.concurrent_outbound_connections_semaphore {
102102
Some(s) => s.acquire().await.ok(),
103103
None => None,
104104
};

crates/factor-outbound-http/src/wasi.rs

Lines changed: 95 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceSta
2626
use tokio::{
2727
io::{AsyncRead, AsyncWrite, ReadBuf},
2828
net::TcpStream,
29-
sync::Semaphore,
29+
sync::{OwnedSemaphorePermit, Semaphore},
3030
time::timeout,
3131
};
3232
use tokio_rustls::client::TlsStream;
@@ -129,9 +129,9 @@ impl WasiHttpView for WasiHttpImplInner<'_> {
129129
self_request_origin: self.state.self_request_origin.clone(),
130130
blocked_networks: self.state.blocked_networks.clone(),
131131
http_clients: self.state.wasi_http_clients.clone(),
132-
concurrent_outbound_requests_semaphore: self
132+
concurrent_outbound_connections_semaphore: self
133133
.state
134-
.concurrent_outbound_requests_semaphore
134+
.concurrent_outbound_connections_semaphore
135135
.clone(),
136136
};
137137
Ok(HostFutureIncomingResponse::Pending(
@@ -158,7 +158,7 @@ struct RequestSender {
158158
self_request_origin: Option<SelfRequestOrigin>,
159159
request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
160160
http_clients: HttpClients,
161-
concurrent_outbound_requests_semaphore: Option<Arc<Semaphore>>,
161+
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
162162
}
163163

164164
impl RequestSender {
@@ -300,12 +300,18 @@ impl RequestSender {
300300
None
301301
};
302302

303+
// If we're limiting concurrent outbound requests, acquire a permit
304+
let permit = match &self.concurrent_outbound_connections_semaphore {
305+
Some(s) => s.clone().acquire_owned().await.ok().map(Arc::new),
306+
None => None,
307+
};
303308
let resp = CONNECT_OPTIONS.scope(
304309
ConnectOptions {
305310
blocked_networks: self.blocked_networks,
306311
connect_timeout,
307312
tls_client_config,
308313
override_connect_addr,
314+
permit,
309315
},
310316
async move {
311317
if use_tls {
@@ -326,17 +332,11 @@ impl RequestSender {
326332
},
327333
);
328334

329-
// If we're limiting concurrent outbound requests, acquire a permit
330-
let permit = match &self.concurrent_outbound_requests_semaphore {
331-
Some(s) => s.acquire().await.ok(),
332-
None => None,
333-
};
334335
let resp = timeout(first_byte_timeout, resp)
335336
.await
336337
.map_err(|_| ErrorCode::ConnectionReadTimeout)?
337338
.map_err(hyper_legacy_request_error)?
338339
.map(|body| body.map_err(hyper_request_error).boxed());
339-
drop(permit);
340340

341341
tracing::Span::current().record("http.response.status_code", resp.status().as_u16());
342342

@@ -378,24 +378,40 @@ impl HttpClients {
378378
}
379379
}
380380

381-
// We must use task-local variables for these config options when using
382-
// `hyper_util::client::legacy::Client::request` because there's no way to plumb
383-
// them through as parameters. Moreover, if there's already a pooled connection
384-
// ready, we'll reuse that and ignore these options anyway.
385381
tokio::task_local! {
382+
/// The options used when establishing a new connection.
383+
///
384+
/// We must use task-local variables for these config options when using
385+
/// `hyper_util::client::legacy::Client::request` because there's no way to plumb
386+
/// them through as parameters. Moreover, if there's already a pooled connection
387+
/// ready, we'll reuse that and ignore these options anyway. After each connection
388+
/// is established, the options are dropped.
386389
static CONNECT_OPTIONS: ConnectOptions;
387390
}
388391

389392
#[derive(Clone)]
390393
struct ConnectOptions {
394+
/// The blocked networks configuration.
391395
blocked_networks: BlockedNetworks,
396+
/// Timeout for establishing a TCP connection.
392397
connect_timeout: Duration,
398+
/// TLS client configuration to use, if any.
393399
tls_client_config: Option<TlsClientConfig>,
400+
/// If set, override the address to connect to instead of using the given `uri`'s authority.
394401
override_connect_addr: Option<SocketAddr>,
402+
/// A permit for this connection
403+
///
404+
/// If there is a permit, it should be dropped when the connection is closed.
405+
permit: Option<Arc<OwnedSemaphorePermit>>,
395406
}
396407

397408
impl ConnectOptions {
398-
async fn connect_tcp(&self, uri: &Uri, default_port: u16) -> Result<TcpStream, ErrorCode> {
409+
/// Establish a TCP connection to the given URI and default port.
410+
async fn connect_tcp(
411+
&self,
412+
uri: &Uri,
413+
default_port: u16,
414+
) -> Result<PermittedTcpStream, ErrorCode> {
399415
let mut socket_addrs = match self.override_connect_addr {
400416
Some(override_connect_addr) => vec![override_connect_addr],
401417
None => {
@@ -430,22 +446,27 @@ impl ConnectOptions {
430446
return Err(ErrorCode::DestinationIpProhibited);
431447
}
432448

433-
timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
449+
let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
434450
.await
435451
.map_err(|_| ErrorCode::ConnectionTimeout)?
436452
.map_err(|err| match err.kind() {
437453
std::io::ErrorKind::AddrNotAvailable => {
438454
dns_error("address not available".into(), 0)
439455
}
440456
_ => ErrorCode::ConnectionRefused,
441-
})
457+
})?;
458+
Ok(PermittedTcpStream {
459+
inner: stream,
460+
_permit: self.permit.clone(),
461+
})
442462
}
443463

464+
/// Establish a TLS connection to the given URI and default port.
444465
async fn connect_tls(
445466
&self,
446467
uri: &Uri,
447468
default_port: u16,
448-
) -> Result<TlsStream<TcpStream>, ErrorCode> {
469+
) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
449470
let tcp_stream = self.connect_tcp(uri, default_port).await?;
450471

451472
let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
@@ -465,20 +486,22 @@ impl ConnectOptions {
465486
}
466487
}
467488

489+
/// A connector the uses `ConnectOptions`
468490
#[derive(Clone)]
469491
struct HttpConnector;
470492

471493
impl HttpConnector {
472-
async fn connect(uri: Uri) -> Result<TokioIo<TcpStream>, ErrorCode> {
494+
async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
473495
let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
474496
Ok(TokioIo::new(stream))
475497
}
476498
}
477499

478500
impl Service<Uri> for HttpConnector {
479-
type Response = TokioIo<TcpStream>;
501+
type Response = TokioIo<PermittedTcpStream>;
480502
type Error = ErrorCode;
481-
type Future = Pin<Box<dyn Future<Output = Result<TokioIo<TcpStream>, ErrorCode>> + Send>>;
503+
type Future =
504+
Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;
482505

483506
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
484507
Poll::Ready(Ok(()))
@@ -489,6 +512,7 @@ impl Service<Uri> for HttpConnector {
489512
}
490513
}
491514

515+
/// A connector that establishes TLS connections using `rustls` and `ConnectOptions`.
492516
#[derive(Clone)]
493517
struct HttpsConnector;
494518

@@ -513,7 +537,7 @@ impl Service<Uri> for HttpsConnector {
513537
}
514538
}
515539

516-
struct RustlsStream(TlsStream<TcpStream>);
540+
struct RustlsStream(TlsStream<PermittedTcpStream>);
517541

518542
impl Connection for RustlsStream {
519543
fn connected(&self) -> Connected {
@@ -568,6 +592,55 @@ impl AsyncWrite for RustlsStream {
568592
}
569593
}
570594

595+
/// A TCP stream that holds an optional permit indicating that it is allowed to exist.
596+
struct PermittedTcpStream {
597+
inner: TcpStream,
598+
_permit: Option<Arc<OwnedSemaphorePermit>>,
599+
}
600+
601+
impl PermittedTcpStream {
602+
fn connected(&self) -> Connected {
603+
self.inner.connected()
604+
}
605+
}
606+
607+
impl Connection for PermittedTcpStream {
608+
fn connected(&self) -> Connected {
609+
self.inner.connected()
610+
}
611+
}
612+
613+
impl AsyncRead for PermittedTcpStream {
614+
fn poll_read(
615+
self: Pin<&mut Self>,
616+
cx: &mut Context<'_>,
617+
buf: &mut ReadBuf<'_>,
618+
) -> Poll<std::io::Result<()>> {
619+
Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
620+
}
621+
}
622+
623+
impl AsyncWrite for PermittedTcpStream {
624+
fn poll_write(
625+
self: Pin<&mut Self>,
626+
cx: &mut Context<'_>,
627+
buf: &[u8],
628+
) -> Poll<Result<usize, std::io::Error>> {
629+
Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
630+
}
631+
632+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
633+
Pin::new(&mut self.get_mut().inner).poll_flush(cx)
634+
}
635+
636+
fn poll_shutdown(
637+
self: Pin<&mut Self>,
638+
cx: &mut Context<'_>,
639+
) -> Poll<Result<(), std::io::Error>> {
640+
Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
641+
}
642+
}
643+
571644
/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
572645
fn hyper_request_error(err: hyper::Error) -> ErrorCode {
573646
// If there's a source, we might be able to extract a wasi-http error from it.

0 commit comments

Comments
 (0)