@@ -172,11 +172,9 @@ impl RequestSender {
172172 }
173173 span. record ( "url.full" , uri. to_string ( ) ) ;
174174
175+ // If the current span has opentelemetry trace context, inject it into the request
175176 spin_telemetry:: inject_trace_context ( & mut request) ;
176177
177- let host = request. uri ( ) . host ( ) . unwrap_or_default ( ) ;
178- let tls_client_config = self . component_tls_configs . get_client_config ( host) . clone ( ) ;
179-
180178 let is_self_request = request
181179 . uri ( )
182180 . authority ( )
@@ -243,34 +241,37 @@ impl RequestSender {
243241 }
244242 }
245243
244+ // Backfill span fields after potentially updating the URL in the interceptor
246245 let authority = request. uri ( ) . authority ( ) . context ( "authority not set" ) ?;
247246 span. record ( "server.address" , authority. host ( ) ) ;
248247 if let Some ( port) = authority. port ( ) {
249248 span. record ( "server.port" , port. as_u16 ( ) ) ;
250249 }
251250
251+ let tls_client_config = if use_tls {
252+ let host = request. uri ( ) . host ( ) . unwrap_or_default ( ) ;
253+ Some ( self . component_tls_configs . get_client_config ( host) . clone ( ) )
254+ } else {
255+ None
256+ } ;
257+
252258 let resp = CONNECT_OPTIONS . scope (
253259 ConnectOptions {
254260 blocked_networks : self . blocked_networks ,
255261 connect_timeout,
262+ tls_client_config,
256263 } ,
257264 async move {
258265 if use_tls {
259- TLS_CLIENT_CONFIG
260- . scope ( tls_client_config, async move {
261- self . http_clients . https . request ( request) . await
262- } )
263- . await
266+ self . http_clients . https . request ( request) . await
264267 } else {
265- let use_http2 = std:: env:: var_os ( "SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE" )
266- . is_some_and ( |v| {
267- request
268- . uri ( )
269- . authority ( )
270- . is_some_and ( |authority| authority. as_str ( ) == v)
271- } ) ;
272-
273- if use_http2 {
268+ // For development purposes, allow configuring plaintext HTTP/2 for a specific host.
269+ let h2c_prior_knowledge_host =
270+ std:: env:: var ( "SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE" ) . ok ( ) ;
271+ let use_h2c = h2c_prior_knowledge_host. as_deref ( )
272+ == request. uri ( ) . authority ( ) . map ( |a| a. as_str ( ) ) ;
273+
274+ if use_h2c {
274275 self . http_clients . http2 . request ( request) . await
275276 } else {
276277 self . http_clients . http1 . request ( request) . await
@@ -325,73 +326,84 @@ impl HttpClients {
325326 }
326327}
327328
328- #[ derive( Clone ) ]
329- struct ConnectOptions {
330- blocked_networks : BlockedNetworks ,
331- connect_timeout : Duration ,
332- }
333-
334329// We must use task-local variables for these config options when using
335330// `hyper_util::client::legacy::Client::request` because there's no way to plumb
336331// them through as parameters. Moreover, if there's already a pooled connection
337332// ready, we'll reuse that and ignore these options anyway.
338333tokio:: task_local! {
339334 static CONNECT_OPTIONS : ConnectOptions ;
340- static TLS_CLIENT_CONFIG : TlsClientConfig ;
341335}
342336
343- async fn connect_tcp ( uri : Uri , default_port : u16 ) -> Result < ( TcpStream , String ) , ErrorCode > {
344- let authority_str = if let Some ( authority) = uri. authority ( ) {
345- if authority. port ( ) . is_some ( ) {
346- authority. to_string ( )
347- } else {
348- format ! ( "{authority}:{default_port}" )
337+ #[ derive( Clone ) ]
338+ struct ConnectOptions {
339+ blocked_networks : BlockedNetworks ,
340+ connect_timeout : Duration ,
341+ tls_client_config : Option < TlsClientConfig > ,
342+ }
343+
344+ impl ConnectOptions {
345+ async fn connect_tcp ( & self , uri : & Uri , default_port : u16 ) -> Result < TcpStream , ErrorCode > {
346+ let host = uri. host ( ) . ok_or ( ErrorCode :: HttpRequestUriInvalid ) ?;
347+ let host_and_port = ( host, uri. port_u16 ( ) . unwrap_or ( default_port) ) ;
348+
349+ let mut socket_addrs = tokio:: net:: lookup_host ( host_and_port)
350+ . await
351+ . map_err ( |_| dns_error ( "address not available" . into ( ) , 0 ) ) ?
352+ . collect :: < Vec < _ > > ( ) ;
353+
354+ // Remove blocked IPs
355+ let blocked_addrs = self . blocked_networks . remove_blocked ( & mut socket_addrs) ;
356+ if socket_addrs. is_empty ( ) && !blocked_addrs. is_empty ( ) {
357+ tracing:: error!(
358+ "error.type" = "destination_ip_prohibited" ,
359+ ?blocked_addrs,
360+ "all destination IP(s) prohibited by runtime config"
361+ ) ;
362+ return Err ( ErrorCode :: DestinationIpProhibited ) ;
349363 }
350- } else {
351- return Err ( ErrorCode :: HttpRequestUriInvalid ) ;
352- } ;
353-
354- let ConnectOptions {
355- blocked_networks,
356- connect_timeout,
357- } = CONNECT_OPTIONS . get ( ) ;
358-
359- let mut socket_addrs = tokio:: net:: lookup_host ( & authority_str)
360- . await
361- . map_err ( |_| dns_error ( "address not available" . into ( ) , 0 ) ) ?
362- . collect :: < Vec < _ > > ( ) ;
363-
364- // Remove blocked IPs
365- let blocked_addrs = blocked_networks. remove_blocked ( & mut socket_addrs) ;
366- if socket_addrs. is_empty ( ) && !blocked_addrs. is_empty ( ) {
367- tracing:: error!(
368- "error.type" = "destination_ip_prohibited" ,
369- ?blocked_addrs,
370- "all destination IP(s) prohibited by runtime config"
371- ) ;
372- return Err ( ErrorCode :: DestinationIpProhibited ) ;
373- }
374364
375- Ok ( (
376- timeout ( connect_timeout, TcpStream :: connect ( socket_addrs. as_slice ( ) ) )
365+ timeout ( self . connect_timeout , TcpStream :: connect ( & * socket_addrs) )
377366 . await
378367 . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
379368 . map_err ( |err| match err. kind ( ) {
380369 std:: io:: ErrorKind :: AddrNotAvailable => {
381370 dns_error ( "address not available" . into ( ) , 0 )
382371 }
383372 _ => ErrorCode :: ConnectionRefused ,
384- } ) ?,
385- authority_str,
386- ) )
373+ } )
374+ }
375+
376+ async fn connect_tls (
377+ & self ,
378+ uri : & Uri ,
379+ default_port : u16 ,
380+ ) -> Result < TlsStream < TcpStream > , ErrorCode > {
381+ let tcp_stream = self . connect_tcp ( uri, default_port) . await ?;
382+
383+ let mut tls_client_config = self . tls_client_config . as_deref ( ) . unwrap ( ) . clone ( ) ;
384+ tls_client_config. alpn_protocols = vec ! [ b"h2" . to_vec( ) , b"http/1.1" . to_vec( ) ] ;
385+
386+ let connector = tokio_rustls:: TlsConnector :: from ( Arc :: new ( tls_client_config) ) ;
387+ let domain = rustls:: pki_types:: ServerName :: try_from ( uri. host ( ) . unwrap ( ) )
388+ . map_err ( |e| {
389+ tracing:: warn!( "dns lookup error: {e:?}" ) ;
390+ dns_error ( "invalid dns name" . to_string ( ) , 0 )
391+ } ) ?
392+ . to_owned ( ) ;
393+ connector. connect ( domain, tcp_stream) . await . map_err ( |e| {
394+ tracing:: warn!( "tls protocol error: {e:?}" ) ;
395+ ErrorCode :: TlsProtocolError
396+ } )
397+ }
387398}
388399
389400#[ derive( Clone ) ]
390401struct HttpConnector ;
391402
392403impl HttpConnector {
393404 async fn connect ( uri : Uri ) -> Result < TokioIo < TcpStream > , ErrorCode > {
394- Ok ( TokioIo :: new ( connect_tcp ( uri, 80 ) . await ?. 0 ) )
405+ let stream = CONNECT_OPTIONS . get ( ) . connect_tcp ( & uri, 80 ) . await ?;
406+ Ok ( TokioIo :: new ( stream) )
395407 }
396408}
397409
@@ -414,27 +426,7 @@ struct HttpsConnector;
414426
415427impl HttpsConnector {
416428 async fn connect ( uri : Uri ) -> Result < TokioIo < RustlsStream > , ErrorCode > {
417- use rustls:: pki_types:: ServerName ;
418-
419- let ( tcp_stream, authority_str) = connect_tcp ( uri, 443 ) . await ?;
420-
421- let mut tls_client_config = ( * TLS_CLIENT_CONFIG . get ( ) ) . clone ( ) ;
422- tls_client_config. alpn_protocols = vec ! [ b"h2" . to_vec( ) , b"http/1.1" . to_vec( ) ] ;
423-
424- let connector = tokio_rustls:: TlsConnector :: from ( Arc :: new ( tls_client_config) ) ;
425- let mut parts = authority_str. split ( ':' ) ;
426- let host = parts. next ( ) . unwrap_or ( & authority_str) ;
427- let domain = ServerName :: try_from ( host)
428- . map_err ( |e| {
429- tracing:: warn!( "dns lookup error: {e:?}" ) ;
430- dns_error ( "invalid dns name" . to_string ( ) , 0 )
431- } ) ?
432- . to_owned ( ) ;
433- let stream = connector. connect ( domain, tcp_stream) . await . map_err ( |e| {
434- tracing:: warn!( "tls protocol error: {e:?}" ) ;
435- ErrorCode :: TlsProtocolError
436- } ) ?;
437-
429+ let stream = CONNECT_OPTIONS . get ( ) . connect_tls ( & uri, 443 ) . await ?;
438430 Ok ( TokioIo :: new ( RustlsStream ( stream) ) )
439431 }
440432}
0 commit comments