@@ -28,7 +28,7 @@ use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceSta
28
28
use tokio:: {
29
29
io:: { AsyncRead , AsyncWrite , ReadBuf } ,
30
30
net:: TcpStream ,
31
- sync:: Semaphore ,
31
+ sync:: { OwnedSemaphorePermit , Semaphore } ,
32
32
time:: timeout,
33
33
} ;
34
34
use tokio_rustls:: client:: TlsStream ;
@@ -91,6 +91,9 @@ impl p3::WasiHttpCtx for InstanceState {
91
91
self_request_origin : self . self_request_origin . clone ( ) ,
92
92
blocked_networks : self . blocked_networks . clone ( ) ,
93
93
http_clients : self . wasi_http_clients . clone ( ) ,
94
+ concurrent_outbound_connections_semaphore : self
95
+ . concurrent_outbound_connections_semaphore
96
+ . clone ( ) ,
94
97
} ;
95
98
let config = OutgoingRequestConfig {
96
99
use_tls : request. uri ( ) . scheme ( ) == Some ( & Scheme :: HTTPS ) ,
@@ -283,9 +286,9 @@ impl WasiHttpView for WasiHttpImplInner<'_> {
283
286
self_request_origin : self . state . self_request_origin . clone ( ) ,
284
287
blocked_networks : self . state . blocked_networks . clone ( ) ,
285
288
http_clients : self . state . wasi_http_clients . clone ( ) ,
286
- concurrent_outbound_requests_semaphore : self
289
+ concurrent_outbound_connections_semaphore : self
287
290
. state
288
- . concurrent_outbound_requests_semaphore
291
+ . concurrent_outbound_connections_semaphore
289
292
. clone ( ) ,
290
293
} ;
291
294
Ok ( HostFutureIncomingResponse :: Pending (
@@ -312,7 +315,7 @@ struct RequestSender {
312
315
self_request_origin : Option < SelfRequestOrigin > ,
313
316
request_interceptor : Option < Arc < dyn OutboundHttpInterceptor > > ,
314
317
http_clients : HttpClients ,
315
- concurrent_outbound_requests_semaphore : Option < Arc < Semaphore > > ,
318
+ concurrent_outbound_connections_semaphore : Option < Arc < Semaphore > > ,
316
319
}
317
320
318
321
impl RequestSender {
@@ -454,12 +457,18 @@ impl RequestSender {
454
457
None
455
458
} ;
456
459
460
+ // If we're limiting concurrent outbound requests, acquire a permit
461
+ let permit = match self . concurrent_outbound_connections_semaphore {
462
+ Some ( s) => s. acquire_owned ( ) . await . ok ( ) . map ( Arc :: new) ,
463
+ None => None ,
464
+ } ;
457
465
let resp = CONNECT_OPTIONS . scope (
458
466
ConnectOptions {
459
467
blocked_networks : self . blocked_networks ,
460
468
connect_timeout,
461
469
tls_client_config,
462
470
override_connect_addr,
471
+ permit,
463
472
} ,
464
473
async move {
465
474
if use_tls {
@@ -480,17 +489,11 @@ impl RequestSender {
480
489
} ,
481
490
) ;
482
491
483
- // If we're limiting concurrent outbound requests, acquire a permit
484
- let permit = match & self . concurrent_outbound_requests_semaphore {
485
- Some ( s) => s. acquire ( ) . await . ok ( ) ,
486
- None => None ,
487
- } ;
488
492
let resp = timeout ( first_byte_timeout, resp)
489
493
. await
490
494
. map_err ( |_| ErrorCode :: ConnectionReadTimeout ) ?
491
495
. map_err ( hyper_legacy_request_error) ?
492
496
. map ( |body| body. map_err ( hyper_request_error) . boxed ( ) ) ;
493
- drop ( permit) ;
494
497
495
498
tracing:: Span :: current ( ) . record ( "http.response.status_code" , resp. status ( ) . as_u16 ( ) ) ;
496
499
@@ -532,24 +535,40 @@ impl HttpClients {
532
535
}
533
536
}
534
537
535
- // We must use task-local variables for these config options when using
536
- // `hyper_util::client::legacy::Client::request` because there's no way to plumb
537
- // them through as parameters. Moreover, if there's already a pooled connection
538
- // ready, we'll reuse that and ignore these options anyway.
539
538
tokio:: task_local! {
539
+ /// The options used when establishing a new connection.
540
+ ///
541
+ /// We must use task-local variables for these config options when using
542
+ /// `hyper_util::client::legacy::Client::request` because there's no way to plumb
543
+ /// them through as parameters. Moreover, if there's already a pooled connection
544
+ /// ready, we'll reuse that and ignore these options anyway. After each connection
545
+ /// is established, the options are dropped.
540
546
static CONNECT_OPTIONS : ConnectOptions ;
541
547
}
542
548
543
549
#[ derive( Clone ) ]
544
550
struct ConnectOptions {
551
+ /// The blocked networks configuration.
545
552
blocked_networks : BlockedNetworks ,
553
+ /// Timeout for establishing a TCP connection.
546
554
connect_timeout : Duration ,
555
+ /// TLS client configuration to use, if any.
547
556
tls_client_config : Option < TlsClientConfig > ,
557
+ /// If set, override the address to connect to instead of using the given `uri`'s authority.
548
558
override_connect_addr : Option < SocketAddr > ,
559
+ /// A permit for this connection
560
+ ///
561
+ /// If there is a permit, it should be dropped when the connection is closed.
562
+ permit : Option < Arc < OwnedSemaphorePermit > > ,
549
563
}
550
564
551
565
impl ConnectOptions {
552
- async fn connect_tcp ( & self , uri : & Uri , default_port : u16 ) -> Result < TcpStream , ErrorCode > {
566
+ /// Establish a TCP connection to the given URI and default port.
567
+ async fn connect_tcp (
568
+ & self ,
569
+ uri : & Uri ,
570
+ default_port : u16 ,
571
+ ) -> Result < PermittedTcpStream , ErrorCode > {
553
572
let mut socket_addrs = match self . override_connect_addr {
554
573
Some ( override_connect_addr) => vec ! [ override_connect_addr] ,
555
574
None => {
@@ -584,22 +603,27 @@ impl ConnectOptions {
584
603
return Err ( ErrorCode :: DestinationIpProhibited ) ;
585
604
}
586
605
587
- timeout ( self . connect_timeout , TcpStream :: connect ( & * socket_addrs) )
606
+ let stream = timeout ( self . connect_timeout , TcpStream :: connect ( & * socket_addrs) )
588
607
. await
589
608
. map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
590
609
. map_err ( |err| match err. kind ( ) {
591
610
std:: io:: ErrorKind :: AddrNotAvailable => {
592
611
dns_error ( "address not available" . into ( ) , 0 )
593
612
}
594
613
_ => ErrorCode :: ConnectionRefused ,
595
- } )
614
+ } ) ?;
615
+ Ok ( PermittedTcpStream {
616
+ inner : stream,
617
+ _permit : self . permit . clone ( ) ,
618
+ } )
596
619
}
597
620
621
+ /// Establish a TLS connection to the given URI and default port.
598
622
async fn connect_tls (
599
623
& self ,
600
624
uri : & Uri ,
601
625
default_port : u16 ,
602
- ) -> Result < TlsStream < TcpStream > , ErrorCode > {
626
+ ) -> Result < TlsStream < PermittedTcpStream > , ErrorCode > {
603
627
let tcp_stream = self . connect_tcp ( uri, default_port) . await ?;
604
628
605
629
let mut tls_client_config = self . tls_client_config . as_deref ( ) . unwrap ( ) . clone ( ) ;
@@ -609,7 +633,7 @@ impl ConnectOptions {
609
633
let domain = rustls:: pki_types:: ServerName :: try_from ( uri. host ( ) . unwrap ( ) )
610
634
. map_err ( |e| {
611
635
tracing:: warn!( "dns lookup error: {e:?}" ) ;
612
- dns_error ( "invalid dns name" . to_string ( ) , 0 )
636
+ dns_error ( "invalid dns name" . into ( ) , 0 )
613
637
} ) ?
614
638
. to_owned ( ) ;
615
639
connector. connect ( domain, tcp_stream) . await . map_err ( |e| {
@@ -619,20 +643,22 @@ impl ConnectOptions {
619
643
}
620
644
}
621
645
646
+ /// A connector the uses `ConnectOptions`
622
647
#[ derive( Clone ) ]
623
648
struct HttpConnector ;
624
649
625
650
impl HttpConnector {
626
- async fn connect ( uri : Uri ) -> Result < TokioIo < TcpStream > , ErrorCode > {
651
+ async fn connect ( uri : Uri ) -> Result < TokioIo < PermittedTcpStream > , ErrorCode > {
627
652
let stream = CONNECT_OPTIONS . get ( ) . connect_tcp ( & uri, 80 ) . await ?;
628
653
Ok ( TokioIo :: new ( stream) )
629
654
}
630
655
}
631
656
632
657
impl Service < Uri > for HttpConnector {
633
- type Response = TokioIo < TcpStream > ;
658
+ type Response = TokioIo < PermittedTcpStream > ;
634
659
type Error = ErrorCode ;
635
- type Future = Pin < Box < dyn Future < Output = Result < TokioIo < TcpStream > , ErrorCode > > + Send > > ;
660
+ type Future =
661
+ Pin < Box < dyn Future < Output = Result < TokioIo < PermittedTcpStream > , ErrorCode > > + Send > > ;
636
662
637
663
fn poll_ready ( & mut self , _cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , Self :: Error > > {
638
664
Poll :: Ready ( Ok ( ( ) ) )
@@ -643,6 +669,7 @@ impl Service<Uri> for HttpConnector {
643
669
}
644
670
}
645
671
672
+ /// A connector that establishes TLS connections using `rustls` and `ConnectOptions`.
646
673
#[ derive( Clone ) ]
647
674
struct HttpsConnector ;
648
675
@@ -667,7 +694,7 @@ impl Service<Uri> for HttpsConnector {
667
694
}
668
695
}
669
696
670
- struct RustlsStream ( TlsStream < TcpStream > ) ;
697
+ struct RustlsStream ( TlsStream < PermittedTcpStream > ) ;
671
698
672
699
impl Connection for RustlsStream {
673
700
fn connected ( & self ) -> Connected {
@@ -722,6 +749,54 @@ impl AsyncWrite for RustlsStream {
722
749
}
723
750
}
724
751
752
+ /// A TCP stream that holds an optional permit indicating that it is allowed to exist.
753
+ struct PermittedTcpStream {
754
+ /// The wrapped TCP stream.
755
+ inner : TcpStream ,
756
+ /// A permit indicating that this stream is allowed to exist.
757
+ ///
758
+ /// When this stream is dropped, the permit is also dropped, allowing another
759
+ /// connection to be established.
760
+ _permit : Option < Arc < OwnedSemaphorePermit > > ,
761
+ }
762
+
763
+ impl Connection for PermittedTcpStream {
764
+ fn connected ( & self ) -> Connected {
765
+ self . inner . connected ( )
766
+ }
767
+ }
768
+
769
+ impl AsyncRead for PermittedTcpStream {
770
+ fn poll_read (
771
+ self : Pin < & mut Self > ,
772
+ cx : & mut Context < ' _ > ,
773
+ buf : & mut ReadBuf < ' _ > ,
774
+ ) -> Poll < std:: io:: Result < ( ) > > {
775
+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_read ( cx, buf)
776
+ }
777
+ }
778
+
779
+ impl AsyncWrite for PermittedTcpStream {
780
+ fn poll_write (
781
+ self : Pin < & mut Self > ,
782
+ cx : & mut Context < ' _ > ,
783
+ buf : & [ u8 ] ,
784
+ ) -> Poll < Result < usize , std:: io:: Error > > {
785
+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_write ( cx, buf)
786
+ }
787
+
788
+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , std:: io:: Error > > {
789
+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_flush ( cx)
790
+ }
791
+
792
+ fn poll_shutdown (
793
+ self : Pin < & mut Self > ,
794
+ cx : & mut Context < ' _ > ,
795
+ ) -> Poll < Result < ( ) , std:: io:: Error > > {
796
+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_shutdown ( cx)
797
+ }
798
+ }
799
+
725
800
/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
726
801
fn hyper_request_error ( err : hyper:: Error ) -> ErrorCode {
727
802
// If there's a source, we might be able to extract a wasi-http error from it.
0 commit comments