Skip to content

Commit 68e5471

Browse files
committed
Limit tcp connections instead of http requests
Signed-off-by: Ryan Levick <[email protected]>
1 parent 5ec1a9d commit 68e5471

File tree

5 files changed

+113
-38
lines changed

5 files changed

+113
-38
lines changed

crates/factor-outbound-http/src/lib.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ impl Factor for OutboundHttpFactor {
5858
Ok(AppState {
5959
wasi_http_clients: wasi::HttpClients::new(config.connection_pooling_enabled),
6060
connection_pooling_enabled: config.connection_pooling_enabled,
61-
concurrent_outbound_requests_semaphore: config
62-
.max_concurrent_requests
63-
// Permit count is the max concurrent requests + 1.
64-
// i.e., 0 concurrent requests means 1 total request.
61+
concurrent_outbound_connections_semaphore: config
62+
.max_concurrent_connections
63+
// Permit count is the max concurrent connections + 1.
64+
// i.e., 0 concurrent connections means 1 total connection.
6565
.map(|n| Arc::new(Semaphore::new(n + 1))),
6666
})
6767
}
@@ -84,9 +84,9 @@ impl Factor for OutboundHttpFactor {
8484
spin_http_client: None,
8585
wasi_http_clients: ctx.app_state().wasi_http_clients.clone(),
8686
connection_pooling_enabled: ctx.app_state().connection_pooling_enabled,
87-
concurrent_outbound_requests_semaphore: ctx
87+
concurrent_outbound_connections_semaphore: ctx
8888
.app_state()
89-
.concurrent_outbound_requests_semaphore
89+
.concurrent_outbound_connections_semaphore
9090
.clone(),
9191
})
9292
}
@@ -112,8 +112,8 @@ pub struct InstanceState {
112112
wasi_http_clients: wasi::HttpClients,
113113
/// Whether connection pooling is enabled for this instance.
114114
connection_pooling_enabled: bool,
115-
/// A semaphore to limit the number of concurrent outbound requests.
116-
concurrent_outbound_requests_semaphore: Option<Arc<Semaphore>>,
115+
/// A semaphore to limit the number of concurrent outbound connections.
116+
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
117117
}
118118

119119
impl InstanceState {
@@ -197,6 +197,6 @@ pub struct AppState {
197197
wasi_http_clients: wasi::HttpClients,
198198
/// Whether connection pooling is enabled for this app.
199199
connection_pooling_enabled: bool,
200-
/// A semaphore to limit the number of concurrent outbound requests.
201-
concurrent_outbound_requests_semaphore: Option<Arc<Semaphore>>,
200+
/// A semaphore to limit the number of concurrent outbound connections.
201+
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
202202
}

crates/factor-outbound-http/src/runtime_config.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@ pub mod spin;
66
pub struct RuntimeConfig {
77
/// If true, enable connection pooling and reuse.
88
pub connection_pooling_enabled: bool,
9-
/// If set, limits the number of concurrent outbound requests.
10-
pub max_concurrent_requests: Option<usize>,
9+
/// If set, limits the number of concurrent outbound connections.
10+
pub max_concurrent_connections: Option<usize>,
1111
}
1212

1313
impl Default for RuntimeConfig {
1414
fn default() -> Self {
1515
Self {
1616
connection_pooling_enabled: true,
17-
max_concurrent_requests: None,
17+
max_concurrent_connections: None,
1818
}
1919
}
2020
}

crates/factor-outbound-http/src/runtime_config/spin.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pub fn config_from_table(
1616
let outbound_http_toml = outbound_http.clone().try_into::<OutboundHttpToml>()?;
1717
Ok(Some(super::RuntimeConfig {
1818
connection_pooling_enabled: outbound_http_toml.connection_pooling,
19-
max_concurrent_requests: outbound_http_toml.max_concurrent_requests,
19+
max_concurrent_connections: outbound_http_toml.max_concurrent_requests,
2020
}))
2121
} else {
2222
Ok(None)

crates/factor-outbound-http/src/spin.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ impl spin_http::Host for crate::InstanceState {
9898
});
9999

100100
// If we're limiting concurrent outbound requests, acquire a permit
101-
let permit = match &self.concurrent_outbound_requests_semaphore {
101+
let permit = match &self.concurrent_outbound_connections_semaphore {
102102
Some(s) => s.acquire().await.ok(),
103103
None => None,
104104
};

crates/factor-outbound-http/src/wasi.rs

Lines changed: 98 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceSta
2828
use tokio::{
2929
io::{AsyncRead, AsyncWrite, ReadBuf},
3030
net::TcpStream,
31-
sync::Semaphore,
31+
sync::{OwnedSemaphorePermit, Semaphore},
3232
time::timeout,
3333
};
3434
use tokio_rustls::client::TlsStream;
@@ -91,6 +91,9 @@ impl p3::WasiHttpCtx for InstanceState {
9191
self_request_origin: self.self_request_origin.clone(),
9292
blocked_networks: self.blocked_networks.clone(),
9393
http_clients: self.wasi_http_clients.clone(),
94+
concurrent_outbound_connections_semaphore: self
95+
.concurrent_outbound_connections_semaphore
96+
.clone(),
9497
};
9598
let config = OutgoingRequestConfig {
9699
use_tls: request.uri().scheme() == Some(&Scheme::HTTPS),
@@ -283,9 +286,9 @@ impl WasiHttpView for WasiHttpImplInner<'_> {
283286
self_request_origin: self.state.self_request_origin.clone(),
284287
blocked_networks: self.state.blocked_networks.clone(),
285288
http_clients: self.state.wasi_http_clients.clone(),
286-
concurrent_outbound_requests_semaphore: self
289+
concurrent_outbound_connections_semaphore: self
287290
.state
288-
.concurrent_outbound_requests_semaphore
291+
.concurrent_outbound_connections_semaphore
289292
.clone(),
290293
};
291294
Ok(HostFutureIncomingResponse::Pending(
@@ -312,7 +315,7 @@ struct RequestSender {
312315
self_request_origin: Option<SelfRequestOrigin>,
313316
request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
314317
http_clients: HttpClients,
315-
concurrent_outbound_requests_semaphore: Option<Arc<Semaphore>>,
318+
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
316319
}
317320

318321
impl RequestSender {
@@ -454,12 +457,18 @@ impl RequestSender {
454457
None
455458
};
456459

460+
// If we're limiting concurrent outbound requests, acquire a permit
461+
let permit = match self.concurrent_outbound_connections_semaphore {
462+
Some(s) => s.acquire_owned().await.ok().map(Arc::new),
463+
None => None,
464+
};
457465
let resp = CONNECT_OPTIONS.scope(
458466
ConnectOptions {
459467
blocked_networks: self.blocked_networks,
460468
connect_timeout,
461469
tls_client_config,
462470
override_connect_addr,
471+
permit,
463472
},
464473
async move {
465474
if use_tls {
@@ -480,17 +489,11 @@ impl RequestSender {
480489
},
481490
);
482491

483-
// If we're limiting concurrent outbound requests, acquire a permit
484-
let permit = match &self.concurrent_outbound_requests_semaphore {
485-
Some(s) => s.acquire().await.ok(),
486-
None => None,
487-
};
488492
let resp = timeout(first_byte_timeout, resp)
489493
.await
490494
.map_err(|_| ErrorCode::ConnectionReadTimeout)?
491495
.map_err(hyper_legacy_request_error)?
492496
.map(|body| body.map_err(hyper_request_error).boxed());
493-
drop(permit);
494497

495498
tracing::Span::current().record("http.response.status_code", resp.status().as_u16());
496499

@@ -532,24 +535,40 @@ impl HttpClients {
532535
}
533536
}
534537

535-
// We must use task-local variables for these config options when using
536-
// `hyper_util::client::legacy::Client::request` because there's no way to plumb
537-
// them through as parameters. Moreover, if there's already a pooled connection
538-
// ready, we'll reuse that and ignore these options anyway.
539538
tokio::task_local! {
539+
/// The options used when establishing a new connection.
540+
///
541+
/// We must use task-local variables for these config options when using
542+
/// `hyper_util::client::legacy::Client::request` because there's no way to plumb
543+
/// them through as parameters. Moreover, if there's already a pooled connection
544+
/// ready, we'll reuse that and ignore these options anyway. After each connection
545+
/// is established, the options are dropped.
540546
static CONNECT_OPTIONS: ConnectOptions;
541547
}
542548

543549
#[derive(Clone)]
544550
struct ConnectOptions {
551+
/// The blocked networks configuration.
545552
blocked_networks: BlockedNetworks,
553+
/// Timeout for establishing a TCP connection.
546554
connect_timeout: Duration,
555+
/// TLS client configuration to use, if any.
547556
tls_client_config: Option<TlsClientConfig>,
557+
/// If set, override the address to connect to instead of using the given `uri`'s authority.
548558
override_connect_addr: Option<SocketAddr>,
559+
/// A permit for this connection
560+
///
561+
/// If there is a permit, it should be dropped when the connection is closed.
562+
permit: Option<Arc<OwnedSemaphorePermit>>,
549563
}
550564

551565
impl ConnectOptions {
552-
async fn connect_tcp(&self, uri: &Uri, default_port: u16) -> Result<TcpStream, ErrorCode> {
566+
/// Establish a TCP connection to the given URI and default port.
567+
async fn connect_tcp(
568+
&self,
569+
uri: &Uri,
570+
default_port: u16,
571+
) -> Result<PermittedTcpStream, ErrorCode> {
553572
let mut socket_addrs = match self.override_connect_addr {
554573
Some(override_connect_addr) => vec![override_connect_addr],
555574
None => {
@@ -584,22 +603,27 @@ impl ConnectOptions {
584603
return Err(ErrorCode::DestinationIpProhibited);
585604
}
586605

587-
timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
606+
let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
588607
.await
589608
.map_err(|_| ErrorCode::ConnectionTimeout)?
590609
.map_err(|err| match err.kind() {
591610
std::io::ErrorKind::AddrNotAvailable => {
592611
dns_error("address not available".into(), 0)
593612
}
594613
_ => ErrorCode::ConnectionRefused,
595-
})
614+
})?;
615+
Ok(PermittedTcpStream {
616+
inner: stream,
617+
_permit: self.permit.clone(),
618+
})
596619
}
597620

621+
/// Establish a TLS connection to the given URI and default port.
598622
async fn connect_tls(
599623
&self,
600624
uri: &Uri,
601625
default_port: u16,
602-
) -> Result<TlsStream<TcpStream>, ErrorCode> {
626+
) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
603627
let tcp_stream = self.connect_tcp(uri, default_port).await?;
604628

605629
let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
@@ -609,7 +633,7 @@ impl ConnectOptions {
609633
let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
610634
.map_err(|e| {
611635
tracing::warn!("dns lookup error: {e:?}");
612-
dns_error("invalid dns name".to_string(), 0)
636+
dns_error("invalid dns name".into(), 0)
613637
})?
614638
.to_owned();
615639
connector.connect(domain, tcp_stream).await.map_err(|e| {
@@ -619,20 +643,22 @@ impl ConnectOptions {
619643
}
620644
}
621645

646+
/// A connector the uses `ConnectOptions`
622647
#[derive(Clone)]
623648
struct HttpConnector;
624649

625650
impl HttpConnector {
626-
async fn connect(uri: Uri) -> Result<TokioIo<TcpStream>, ErrorCode> {
651+
async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
627652
let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
628653
Ok(TokioIo::new(stream))
629654
}
630655
}
631656

632657
impl Service<Uri> for HttpConnector {
633-
type Response = TokioIo<TcpStream>;
658+
type Response = TokioIo<PermittedTcpStream>;
634659
type Error = ErrorCode;
635-
type Future = Pin<Box<dyn Future<Output = Result<TokioIo<TcpStream>, ErrorCode>> + Send>>;
660+
type Future =
661+
Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;
636662

637663
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
638664
Poll::Ready(Ok(()))
@@ -643,6 +669,7 @@ impl Service<Uri> for HttpConnector {
643669
}
644670
}
645671

672+
/// A connector that establishes TLS connections using `rustls` and `ConnectOptions`.
646673
#[derive(Clone)]
647674
struct HttpsConnector;
648675

@@ -667,7 +694,7 @@ impl Service<Uri> for HttpsConnector {
667694
}
668695
}
669696

670-
struct RustlsStream(TlsStream<TcpStream>);
697+
struct RustlsStream(TlsStream<PermittedTcpStream>);
671698

672699
impl Connection for RustlsStream {
673700
fn connected(&self) -> Connected {
@@ -722,6 +749,54 @@ impl AsyncWrite for RustlsStream {
722749
}
723750
}
724751

752+
/// A TCP stream that holds an optional permit indicating that it is allowed to exist.
753+
struct PermittedTcpStream {
754+
/// The wrapped TCP stream.
755+
inner: TcpStream,
756+
/// A permit indicating that this stream is allowed to exist.
757+
///
758+
/// When this stream is dropped, the permit is also dropped, allowing another
759+
/// connection to be established.
760+
_permit: Option<Arc<OwnedSemaphorePermit>>,
761+
}
762+
763+
impl Connection for PermittedTcpStream {
764+
fn connected(&self) -> Connected {
765+
self.inner.connected()
766+
}
767+
}
768+
769+
impl AsyncRead for PermittedTcpStream {
770+
fn poll_read(
771+
self: Pin<&mut Self>,
772+
cx: &mut Context<'_>,
773+
buf: &mut ReadBuf<'_>,
774+
) -> Poll<std::io::Result<()>> {
775+
Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
776+
}
777+
}
778+
779+
impl AsyncWrite for PermittedTcpStream {
780+
fn poll_write(
781+
self: Pin<&mut Self>,
782+
cx: &mut Context<'_>,
783+
buf: &[u8],
784+
) -> Poll<Result<usize, std::io::Error>> {
785+
Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
786+
}
787+
788+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
789+
Pin::new(&mut self.get_mut().inner).poll_flush(cx)
790+
}
791+
792+
fn poll_shutdown(
793+
self: Pin<&mut Self>,
794+
cx: &mut Context<'_>,
795+
) -> Poll<Result<(), std::io::Error>> {
796+
Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
797+
}
798+
}
799+
725800
/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
726801
fn hyper_request_error(err: hyper::Error) -> ErrorCode {
727802
// If there's a source, we might be able to extract a wasi-http error from it.

0 commit comments

Comments
 (0)