@@ -26,7 +26,7 @@ use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceSta
26
26
use tokio:: {
27
27
io:: { AsyncRead , AsyncWrite , ReadBuf } ,
28
28
net:: TcpStream ,
29
- sync:: Semaphore ,
29
+ sync:: { OwnedSemaphorePermit , Semaphore } ,
30
30
time:: timeout,
31
31
} ;
32
32
use tokio_rustls:: client:: TlsStream ;
@@ -129,9 +129,9 @@ impl WasiHttpView for WasiHttpImplInner<'_> {
129
129
self_request_origin : self . state . self_request_origin . clone ( ) ,
130
130
blocked_networks : self . state . blocked_networks . clone ( ) ,
131
131
http_clients : self . state . wasi_http_clients . clone ( ) ,
132
- concurrent_outbound_requests_semaphore : self
132
+ concurrent_outbound_connections_semaphore : self
133
133
. state
134
- . concurrent_outbound_requests_semaphore
134
+ . concurrent_outbound_connections_semaphore
135
135
. clone ( ) ,
136
136
} ;
137
137
Ok ( HostFutureIncomingResponse :: Pending (
@@ -158,7 +158,7 @@ struct RequestSender {
158
158
self_request_origin : Option < SelfRequestOrigin > ,
159
159
request_interceptor : Option < Arc < dyn OutboundHttpInterceptor > > ,
160
160
http_clients : HttpClients ,
161
- concurrent_outbound_requests_semaphore : Option < Arc < Semaphore > > ,
161
+ concurrent_outbound_connections_semaphore : Option < Arc < Semaphore > > ,
162
162
}
163
163
164
164
impl RequestSender {
@@ -300,12 +300,18 @@ impl RequestSender {
300
300
None
301
301
} ;
302
302
303
+ // If we're limiting concurrent outbound requests, acquire a permit
304
+ let permit = match self . concurrent_outbound_connections_semaphore {
305
+ Some ( s) => s. acquire_owned ( ) . await . ok ( ) . map ( Arc :: new) ,
306
+ None => None ,
307
+ } ;
303
308
let resp = CONNECT_OPTIONS . scope (
304
309
ConnectOptions {
305
310
blocked_networks : self . blocked_networks ,
306
311
connect_timeout,
307
312
tls_client_config,
308
313
override_connect_addr,
314
+ permit,
309
315
} ,
310
316
async move {
311
317
if use_tls {
@@ -326,17 +332,11 @@ impl RequestSender {
326
332
} ,
327
333
) ;
328
334
329
- // If we're limiting concurrent outbound requests, acquire a permit
330
- let permit = match & self . concurrent_outbound_requests_semaphore {
331
- Some ( s) => s. acquire ( ) . await . ok ( ) ,
332
- None => None ,
333
- } ;
334
335
let resp = timeout ( first_byte_timeout, resp)
335
336
. await
336
337
. map_err ( |_| ErrorCode :: ConnectionReadTimeout ) ?
337
338
. map_err ( hyper_legacy_request_error) ?
338
339
. map ( |body| body. map_err ( hyper_request_error) . boxed ( ) ) ;
339
- drop ( permit) ;
340
340
341
341
tracing:: Span :: current ( ) . record ( "http.response.status_code" , resp. status ( ) . as_u16 ( ) ) ;
342
342
@@ -378,24 +378,40 @@ impl HttpClients {
378
378
}
379
379
}
380
380
381
- // We must use task-local variables for these config options when using
382
- // `hyper_util::client::legacy::Client::request` because there's no way to plumb
383
- // them through as parameters. Moreover, if there's already a pooled connection
384
- // ready, we'll reuse that and ignore these options anyway.
385
381
tokio:: task_local! {
382
+ /// The options used when establishing a new connection.
383
+ ///
384
+ /// We must use task-local variables for these config options when using
385
+ /// `hyper_util::client::legacy::Client::request` because there's no way to plumb
386
+ /// them through as parameters. Moreover, if there's already a pooled connection
387
+ /// ready, we'll reuse that and ignore these options anyway. After each connection
388
+ /// is established, the options are dropped.
386
389
static CONNECT_OPTIONS : ConnectOptions ;
387
390
}
388
391
389
392
#[ derive( Clone ) ]
390
393
struct ConnectOptions {
394
+ /// The blocked networks configuration.
391
395
blocked_networks : BlockedNetworks ,
396
+ /// Timeout for establishing a TCP connection.
392
397
connect_timeout : Duration ,
398
+ /// TLS client configuration to use, if any.
393
399
tls_client_config : Option < TlsClientConfig > ,
400
+ /// If set, override the address to connect to instead of using the given `uri`'s authority.
394
401
override_connect_addr : Option < SocketAddr > ,
402
+ /// A permit for this connection
403
+ ///
404
+ /// If there is a permit, it should be dropped when the connection is closed.
405
+ permit : Option < Arc < OwnedSemaphorePermit > > ,
395
406
}
396
407
397
408
impl ConnectOptions {
398
- async fn connect_tcp ( & self , uri : & Uri , default_port : u16 ) -> Result < TcpStream , ErrorCode > {
409
+ /// Establish a TCP connection to the given URI and default port.
410
+ async fn connect_tcp (
411
+ & self ,
412
+ uri : & Uri ,
413
+ default_port : u16 ,
414
+ ) -> Result < PermittedTcpStream , ErrorCode > {
399
415
let mut socket_addrs = match self . override_connect_addr {
400
416
Some ( override_connect_addr) => vec ! [ override_connect_addr] ,
401
417
None => {
@@ -430,22 +446,27 @@ impl ConnectOptions {
430
446
return Err ( ErrorCode :: DestinationIpProhibited ) ;
431
447
}
432
448
433
- timeout ( self . connect_timeout , TcpStream :: connect ( & * socket_addrs) )
449
+ let stream = timeout ( self . connect_timeout , TcpStream :: connect ( & * socket_addrs) )
434
450
. await
435
451
. map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
436
452
. map_err ( |err| match err. kind ( ) {
437
453
std:: io:: ErrorKind :: AddrNotAvailable => {
438
454
dns_error ( "address not available" . into ( ) , 0 )
439
455
}
440
456
_ => ErrorCode :: ConnectionRefused ,
441
- } )
457
+ } ) ?;
458
+ Ok ( PermittedTcpStream {
459
+ inner : stream,
460
+ _permit : self . permit . clone ( ) ,
461
+ } )
442
462
}
443
463
464
+ /// Establish a TLS connection to the given URI and default port.
444
465
async fn connect_tls (
445
466
& self ,
446
467
uri : & Uri ,
447
468
default_port : u16 ,
448
- ) -> Result < TlsStream < TcpStream > , ErrorCode > {
469
+ ) -> Result < TlsStream < PermittedTcpStream > , ErrorCode > {
449
470
let tcp_stream = self . connect_tcp ( uri, default_port) . await ?;
450
471
451
472
let mut tls_client_config = self . tls_client_config . as_deref ( ) . unwrap ( ) . clone ( ) ;
@@ -455,7 +476,7 @@ impl ConnectOptions {
455
476
let domain = rustls:: pki_types:: ServerName :: try_from ( uri. host ( ) . unwrap ( ) )
456
477
. map_err ( |e| {
457
478
tracing:: warn!( "dns lookup error: {e:?}" ) ;
458
- dns_error ( "invalid dns name" . to_string ( ) , 0 )
479
+ dns_error ( "invalid dns name" . into ( ) , 0 )
459
480
} ) ?
460
481
. to_owned ( ) ;
461
482
connector. connect ( domain, tcp_stream) . await . map_err ( |e| {
@@ -465,20 +486,22 @@ impl ConnectOptions {
465
486
}
466
487
}
467
488
489
+ /// A connector the uses `ConnectOptions`
468
490
#[ derive( Clone ) ]
469
491
struct HttpConnector ;
470
492
471
493
impl HttpConnector {
472
- async fn connect ( uri : Uri ) -> Result < TokioIo < TcpStream > , ErrorCode > {
494
+ async fn connect ( uri : Uri ) -> Result < TokioIo < PermittedTcpStream > , ErrorCode > {
473
495
let stream = CONNECT_OPTIONS . get ( ) . connect_tcp ( & uri, 80 ) . await ?;
474
496
Ok ( TokioIo :: new ( stream) )
475
497
}
476
498
}
477
499
478
500
impl Service < Uri > for HttpConnector {
479
- type Response = TokioIo < TcpStream > ;
501
+ type Response = TokioIo < PermittedTcpStream > ;
480
502
type Error = ErrorCode ;
481
- type Future = Pin < Box < dyn Future < Output = Result < TokioIo < TcpStream > , ErrorCode > > + Send > > ;
503
+ type Future =
504
+ Pin < Box < dyn Future < Output = Result < TokioIo < PermittedTcpStream > , ErrorCode > > + Send > > ;
482
505
483
506
fn poll_ready ( & mut self , _cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , Self :: Error > > {
484
507
Poll :: Ready ( Ok ( ( ) ) )
@@ -489,6 +512,7 @@ impl Service<Uri> for HttpConnector {
489
512
}
490
513
}
491
514
515
+ /// A connector that establishes TLS connections using `rustls` and `ConnectOptions`.
492
516
#[ derive( Clone ) ]
493
517
struct HttpsConnector ;
494
518
@@ -513,7 +537,7 @@ impl Service<Uri> for HttpsConnector {
513
537
}
514
538
}
515
539
516
- struct RustlsStream ( TlsStream < TcpStream > ) ;
540
+ struct RustlsStream ( TlsStream < PermittedTcpStream > ) ;
517
541
518
542
impl Connection for RustlsStream {
519
543
fn connected ( & self ) -> Connected {
@@ -568,6 +592,54 @@ impl AsyncWrite for RustlsStream {
568
592
}
569
593
}
570
594
595
+ /// A TCP stream that holds an optional permit indicating that it is allowed to exist.
596
+ struct PermittedTcpStream {
597
+ /// The wrapped TCP stream.
598
+ inner : TcpStream ,
599
+ /// A permit indicating that this stream is allowed to exist.
600
+ ///
601
+ /// When this stream is dropped, the permit is also dropped, allowing another
602
+ /// connection to be established.
603
+ _permit : Option < Arc < OwnedSemaphorePermit > > ,
604
+ }
605
+
606
+ impl Connection for PermittedTcpStream {
607
+ fn connected ( & self ) -> Connected {
608
+ self . inner . connected ( )
609
+ }
610
+ }
611
+
612
+ impl AsyncRead for PermittedTcpStream {
613
+ fn poll_read (
614
+ self : Pin < & mut Self > ,
615
+ cx : & mut Context < ' _ > ,
616
+ buf : & mut ReadBuf < ' _ > ,
617
+ ) -> Poll < std:: io:: Result < ( ) > > {
618
+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_read ( cx, buf)
619
+ }
620
+ }
621
+
622
+ impl AsyncWrite for PermittedTcpStream {
623
+ fn poll_write (
624
+ self : Pin < & mut Self > ,
625
+ cx : & mut Context < ' _ > ,
626
+ buf : & [ u8 ] ,
627
+ ) -> Poll < Result < usize , std:: io:: Error > > {
628
+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_write ( cx, buf)
629
+ }
630
+
631
+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , std:: io:: Error > > {
632
+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_flush ( cx)
633
+ }
634
+
635
+ fn poll_shutdown (
636
+ self : Pin < & mut Self > ,
637
+ cx : & mut Context < ' _ > ,
638
+ ) -> Poll < Result < ( ) , std:: io:: Error > > {
639
+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_shutdown ( cx)
640
+ }
641
+ }
642
+
571
643
/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
572
644
fn hyper_request_error ( err : hyper:: Error ) -> ErrorCode {
573
645
// If there's a source, we might be able to extract a wasi-http error from it.
0 commit comments