@@ -28,7 +28,7 @@ use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceSta
2828use tokio:: {
2929 io:: { AsyncRead , AsyncWrite , ReadBuf } ,
3030 net:: TcpStream ,
31- sync:: Semaphore ,
31+ sync:: { OwnedSemaphorePermit , Semaphore } ,
3232 time:: timeout,
3333} ;
3434use tokio_rustls:: client:: TlsStream ;
@@ -91,6 +91,9 @@ impl p3::WasiHttpCtx for InstanceState {
9191 self_request_origin : self . self_request_origin . clone ( ) ,
9292 blocked_networks : self . blocked_networks . clone ( ) ,
9393 http_clients : self . wasi_http_clients . clone ( ) ,
94+ concurrent_outbound_connections_semaphore : self
95+ . concurrent_outbound_connections_semaphore
96+ . clone ( ) ,
9497 } ;
9598 let config = OutgoingRequestConfig {
9699 use_tls : request. uri ( ) . scheme ( ) == Some ( & Scheme :: HTTPS ) ,
@@ -283,9 +286,9 @@ impl WasiHttpView for WasiHttpImplInner<'_> {
283286 self_request_origin : self . state . self_request_origin . clone ( ) ,
284287 blocked_networks : self . state . blocked_networks . clone ( ) ,
285288 http_clients : self . state . wasi_http_clients . clone ( ) ,
286- concurrent_outbound_requests_semaphore : self
289+ concurrent_outbound_connections_semaphore : self
287290 . state
288- . concurrent_outbound_requests_semaphore
291+ . concurrent_outbound_connections_semaphore
289292 . clone ( ) ,
290293 } ;
291294 Ok ( HostFutureIncomingResponse :: Pending (
@@ -312,7 +315,7 @@ struct RequestSender {
312315 self_request_origin : Option < SelfRequestOrigin > ,
313316 request_interceptor : Option < Arc < dyn OutboundHttpInterceptor > > ,
314317 http_clients : HttpClients ,
315- concurrent_outbound_requests_semaphore : Option < Arc < Semaphore > > ,
318+ concurrent_outbound_connections_semaphore : Option < Arc < Semaphore > > ,
316319}
317320
318321impl RequestSender {
@@ -454,12 +457,18 @@ impl RequestSender {
454457 None
455458 } ;
456459
460+ // If we're limiting concurrent outbound requests, acquire a permit
461+ let permit = match self . concurrent_outbound_connections_semaphore {
462+ Some ( s) => s. acquire_owned ( ) . await . ok ( ) . map ( Arc :: new) ,
463+ None => None ,
464+ } ;
457465 let resp = CONNECT_OPTIONS . scope (
458466 ConnectOptions {
459467 blocked_networks : self . blocked_networks ,
460468 connect_timeout,
461469 tls_client_config,
462470 override_connect_addr,
471+ permit,
463472 } ,
464473 async move {
465474 if use_tls {
@@ -480,17 +489,11 @@ impl RequestSender {
480489 } ,
481490 ) ;
482491
483- // If we're limiting concurrent outbound requests, acquire a permit
484- let permit = match & self . concurrent_outbound_requests_semaphore {
485- Some ( s) => s. acquire ( ) . await . ok ( ) ,
486- None => None ,
487- } ;
488492 let resp = timeout ( first_byte_timeout, resp)
489493 . await
490494 . map_err ( |_| ErrorCode :: ConnectionReadTimeout ) ?
491495 . map_err ( hyper_legacy_request_error) ?
492496 . map ( |body| body. map_err ( hyper_request_error) . boxed ( ) ) ;
493- drop ( permit) ;
494497
495498 tracing:: Span :: current ( ) . record ( "http.response.status_code" , resp. status ( ) . as_u16 ( ) ) ;
496499
@@ -532,24 +535,40 @@ impl HttpClients {
532535 }
533536}
534537
535- // We must use task-local variables for these config options when using
536- // `hyper_util::client::legacy::Client::request` because there's no way to plumb
537- // them through as parameters. Moreover, if there's already a pooled connection
538- // ready, we'll reuse that and ignore these options anyway.
539538tokio:: task_local! {
539+ /// The options used when establishing a new connection.
540+ ///
541+ /// We must use task-local variables for these config options when using
542+ /// `hyper_util::client::legacy::Client::request` because there's no way to plumb
543+ /// them through as parameters. Moreover, if there's already a pooled connection
544+ /// ready, we'll reuse that and ignore these options anyway. After each connection
545+ /// is established, the options are dropped.
540546 static CONNECT_OPTIONS : ConnectOptions ;
541547}
542548
543549#[ derive( Clone ) ]
544550struct ConnectOptions {
551+ /// The blocked networks configuration.
545552 blocked_networks : BlockedNetworks ,
553+ /// Timeout for establishing a TCP connection.
546554 connect_timeout : Duration ,
555+ /// TLS client configuration to use, if any.
547556 tls_client_config : Option < TlsClientConfig > ,
557+ /// If set, override the address to connect to instead of using the given `uri`'s authority.
548558 override_connect_addr : Option < SocketAddr > ,
559+ /// A permit for this connection
560+ ///
561+ /// If there is a permit, it should be dropped when the connection is closed.
562+ permit : Option < Arc < OwnedSemaphorePermit > > ,
549563}
550564
551565impl ConnectOptions {
552- async fn connect_tcp ( & self , uri : & Uri , default_port : u16 ) -> Result < TcpStream , ErrorCode > {
566+ /// Establish a TCP connection to the given URI and default port.
567+ async fn connect_tcp (
568+ & self ,
569+ uri : & Uri ,
570+ default_port : u16 ,
571+ ) -> Result < PermittedTcpStream , ErrorCode > {
553572 let mut socket_addrs = match self . override_connect_addr {
554573 Some ( override_connect_addr) => vec ! [ override_connect_addr] ,
555574 None => {
@@ -584,22 +603,27 @@ impl ConnectOptions {
584603 return Err ( ErrorCode :: DestinationIpProhibited ) ;
585604 }
586605
587- timeout ( self . connect_timeout , TcpStream :: connect ( & * socket_addrs) )
606+ let stream = timeout ( self . connect_timeout , TcpStream :: connect ( & * socket_addrs) )
588607 . await
589608 . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
590609 . map_err ( |err| match err. kind ( ) {
591610 std:: io:: ErrorKind :: AddrNotAvailable => {
592611 dns_error ( "address not available" . into ( ) , 0 )
593612 }
594613 _ => ErrorCode :: ConnectionRefused ,
595- } )
614+ } ) ?;
615+ Ok ( PermittedTcpStream {
616+ inner : stream,
617+ _permit : self . permit . clone ( ) ,
618+ } )
596619 }
597620
621+ /// Establish a TLS connection to the given URI and default port.
598622 async fn connect_tls (
599623 & self ,
600624 uri : & Uri ,
601625 default_port : u16 ,
602- ) -> Result < TlsStream < TcpStream > , ErrorCode > {
626+ ) -> Result < TlsStream < PermittedTcpStream > , ErrorCode > {
603627 let tcp_stream = self . connect_tcp ( uri, default_port) . await ?;
604628
605629 let mut tls_client_config = self . tls_client_config . as_deref ( ) . unwrap ( ) . clone ( ) ;
@@ -609,7 +633,7 @@ impl ConnectOptions {
609633 let domain = rustls:: pki_types:: ServerName :: try_from ( uri. host ( ) . unwrap ( ) )
610634 . map_err ( |e| {
611635 tracing:: warn!( "dns lookup error: {e:?}" ) ;
612- dns_error ( "invalid dns name" . to_string ( ) , 0 )
636+ dns_error ( "invalid dns name" . into ( ) , 0 )
613637 } ) ?
614638 . to_owned ( ) ;
615639 connector. connect ( domain, tcp_stream) . await . map_err ( |e| {
@@ -619,20 +643,22 @@ impl ConnectOptions {
619643 }
620644}
621645
646+ /// A connector the uses `ConnectOptions`
622647#[ derive( Clone ) ]
623648struct HttpConnector ;
624649
625650impl HttpConnector {
626- async fn connect ( uri : Uri ) -> Result < TokioIo < TcpStream > , ErrorCode > {
651+ async fn connect ( uri : Uri ) -> Result < TokioIo < PermittedTcpStream > , ErrorCode > {
627652 let stream = CONNECT_OPTIONS . get ( ) . connect_tcp ( & uri, 80 ) . await ?;
628653 Ok ( TokioIo :: new ( stream) )
629654 }
630655}
631656
632657impl Service < Uri > for HttpConnector {
633- type Response = TokioIo < TcpStream > ;
658+ type Response = TokioIo < PermittedTcpStream > ;
634659 type Error = ErrorCode ;
635- type Future = Pin < Box < dyn Future < Output = Result < TokioIo < TcpStream > , ErrorCode > > + Send > > ;
660+ type Future =
661+ Pin < Box < dyn Future < Output = Result < TokioIo < PermittedTcpStream > , ErrorCode > > + Send > > ;
636662
637663 fn poll_ready ( & mut self , _cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , Self :: Error > > {
638664 Poll :: Ready ( Ok ( ( ) ) )
@@ -643,6 +669,7 @@ impl Service<Uri> for HttpConnector {
643669 }
644670}
645671
672+ /// A connector that establishes TLS connections using `rustls` and `ConnectOptions`.
646673#[ derive( Clone ) ]
647674struct HttpsConnector ;
648675
@@ -667,7 +694,7 @@ impl Service<Uri> for HttpsConnector {
667694 }
668695}
669696
670- struct RustlsStream ( TlsStream < TcpStream > ) ;
697+ struct RustlsStream ( TlsStream < PermittedTcpStream > ) ;
671698
672699impl Connection for RustlsStream {
673700 fn connected ( & self ) -> Connected {
@@ -722,6 +749,54 @@ impl AsyncWrite for RustlsStream {
722749 }
723750}
724751
752+ /// A TCP stream that holds an optional permit indicating that it is allowed to exist.
753+ struct PermittedTcpStream {
754+ /// The wrapped TCP stream.
755+ inner : TcpStream ,
756+ /// A permit indicating that this stream is allowed to exist.
757+ ///
758+ /// When this stream is dropped, the permit is also dropped, allowing another
759+ /// connection to be established.
760+ _permit : Option < Arc < OwnedSemaphorePermit > > ,
761+ }
762+
763+ impl Connection for PermittedTcpStream {
764+ fn connected ( & self ) -> Connected {
765+ self . inner . connected ( )
766+ }
767+ }
768+
769+ impl AsyncRead for PermittedTcpStream {
770+ fn poll_read (
771+ self : Pin < & mut Self > ,
772+ cx : & mut Context < ' _ > ,
773+ buf : & mut ReadBuf < ' _ > ,
774+ ) -> Poll < std:: io:: Result < ( ) > > {
775+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_read ( cx, buf)
776+ }
777+ }
778+
779+ impl AsyncWrite for PermittedTcpStream {
780+ fn poll_write (
781+ self : Pin < & mut Self > ,
782+ cx : & mut Context < ' _ > ,
783+ buf : & [ u8 ] ,
784+ ) -> Poll < Result < usize , std:: io:: Error > > {
785+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_write ( cx, buf)
786+ }
787+
788+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , std:: io:: Error > > {
789+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_flush ( cx)
790+ }
791+
792+ fn poll_shutdown (
793+ self : Pin < & mut Self > ,
794+ cx : & mut Context < ' _ > ,
795+ ) -> Poll < Result < ( ) , std:: io:: Error > > {
796+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_shutdown ( cx)
797+ }
798+ }
799+
725800/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
726801fn hyper_request_error ( err : hyper:: Error ) -> ErrorCode {
727802 // If there's a source, we might be able to extract a wasi-http error from it.
0 commit comments