@@ -172,11 +172,9 @@ impl RequestSender {
172
172
}
173
173
span. record ( "url.full" , uri. to_string ( ) ) ;
174
174
175
+ // If the current span has opentelemetry trace context, inject it into the request
175
176
spin_telemetry:: inject_trace_context ( & mut request) ;
176
177
177
- let host = request. uri ( ) . host ( ) . unwrap_or_default ( ) ;
178
- let tls_client_config = self . component_tls_configs . get_client_config ( host) . clone ( ) ;
179
-
180
178
let is_self_request = request
181
179
. uri ( )
182
180
. authority ( )
@@ -243,34 +241,37 @@ impl RequestSender {
243
241
}
244
242
}
245
243
244
+ // Backfill span fields after potentially updating the URL in the interceptor
246
245
let authority = request. uri ( ) . authority ( ) . context ( "authority not set" ) ?;
247
246
span. record ( "server.address" , authority. host ( ) ) ;
248
247
if let Some ( port) = authority. port ( ) {
249
248
span. record ( "server.port" , port. as_u16 ( ) ) ;
250
249
}
251
250
251
+ let tls_client_config = if use_tls {
252
+ let host = request. uri ( ) . host ( ) . unwrap_or_default ( ) ;
253
+ Some ( self . component_tls_configs . get_client_config ( host) . clone ( ) )
254
+ } else {
255
+ None
256
+ } ;
257
+
252
258
let resp = CONNECT_OPTIONS . scope (
253
259
ConnectOptions {
254
260
blocked_networks : self . blocked_networks ,
255
261
connect_timeout,
262
+ tls_client_config,
256
263
} ,
257
264
async move {
258
265
if use_tls {
259
- TLS_CLIENT_CONFIG
260
- . scope ( tls_client_config, async move {
261
- self . http_clients . https . request ( request) . await
262
- } )
263
- . await
266
+ self . http_clients . https . request ( request) . await
264
267
} else {
265
- let use_http2 = std:: env:: var_os ( "SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE" )
266
- . is_some_and ( |v| {
267
- request
268
- . uri ( )
269
- . authority ( )
270
- . is_some_and ( |authority| authority. as_str ( ) == v)
271
- } ) ;
272
-
273
- if use_http2 {
268
+ // For development purposes, allow configuring plaintext HTTP/2 for a specific host.
269
+ let h2c_prior_knowledge_host =
270
+ std:: env:: var ( "SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE" ) . ok ( ) ;
271
+ let use_h2c = h2c_prior_knowledge_host. as_deref ( )
272
+ == request. uri ( ) . authority ( ) . map ( |a| a. as_str ( ) ) ;
273
+
274
+ if use_h2c {
274
275
self . http_clients . http2 . request ( request) . await
275
276
} else {
276
277
self . http_clients . http1 . request ( request) . await
@@ -325,73 +326,84 @@ impl HttpClients {
325
326
}
326
327
}
327
328
328
- #[ derive( Clone ) ]
329
- struct ConnectOptions {
330
- blocked_networks : BlockedNetworks ,
331
- connect_timeout : Duration ,
332
- }
333
-
334
329
// We must use task-local variables for these config options when using
335
330
// `hyper_util::client::legacy::Client::request` because there's no way to plumb
336
331
// them through as parameters. Moreover, if there's already a pooled connection
337
332
// ready, we'll reuse that and ignore these options anyway.
338
333
tokio:: task_local! {
339
334
static CONNECT_OPTIONS : ConnectOptions ;
340
- static TLS_CLIENT_CONFIG : TlsClientConfig ;
341
335
}
342
336
343
- async fn connect_tcp ( uri : Uri , default_port : u16 ) -> Result < ( TcpStream , String ) , ErrorCode > {
344
- let authority_str = if let Some ( authority) = uri. authority ( ) {
345
- if authority. port ( ) . is_some ( ) {
346
- authority. to_string ( )
347
- } else {
348
- format ! ( "{authority}:{default_port}" )
337
+ #[ derive( Clone ) ]
338
+ struct ConnectOptions {
339
+ blocked_networks : BlockedNetworks ,
340
+ connect_timeout : Duration ,
341
+ tls_client_config : Option < TlsClientConfig > ,
342
+ }
343
+
344
+ impl ConnectOptions {
345
+ async fn connect_tcp ( & self , uri : & Uri , default_port : u16 ) -> Result < TcpStream , ErrorCode > {
346
+ let host = uri. host ( ) . ok_or ( ErrorCode :: HttpRequestUriInvalid ) ?;
347
+ let host_and_port = ( host, uri. port_u16 ( ) . unwrap_or ( default_port) ) ;
348
+
349
+ let mut socket_addrs = tokio:: net:: lookup_host ( host_and_port)
350
+ . await
351
+ . map_err ( |_| dns_error ( "address not available" . into ( ) , 0 ) ) ?
352
+ . collect :: < Vec < _ > > ( ) ;
353
+
354
+ // Remove blocked IPs
355
+ let blocked_addrs = self . blocked_networks . remove_blocked ( & mut socket_addrs) ;
356
+ if socket_addrs. is_empty ( ) && !blocked_addrs. is_empty ( ) {
357
+ tracing:: error!(
358
+ "error.type" = "destination_ip_prohibited" ,
359
+ ?blocked_addrs,
360
+ "all destination IP(s) prohibited by runtime config"
361
+ ) ;
362
+ return Err ( ErrorCode :: DestinationIpProhibited ) ;
349
363
}
350
- } else {
351
- return Err ( ErrorCode :: HttpRequestUriInvalid ) ;
352
- } ;
353
-
354
- let ConnectOptions {
355
- blocked_networks,
356
- connect_timeout,
357
- } = CONNECT_OPTIONS . get ( ) ;
358
-
359
- let mut socket_addrs = tokio:: net:: lookup_host ( & authority_str)
360
- . await
361
- . map_err ( |_| dns_error ( "address not available" . into ( ) , 0 ) ) ?
362
- . collect :: < Vec < _ > > ( ) ;
363
-
364
- // Remove blocked IPs
365
- let blocked_addrs = blocked_networks. remove_blocked ( & mut socket_addrs) ;
366
- if socket_addrs. is_empty ( ) && !blocked_addrs. is_empty ( ) {
367
- tracing:: error!(
368
- "error.type" = "destination_ip_prohibited" ,
369
- ?blocked_addrs,
370
- "all destination IP(s) prohibited by runtime config"
371
- ) ;
372
- return Err ( ErrorCode :: DestinationIpProhibited ) ;
373
- }
374
364
375
- Ok ( (
376
- timeout ( connect_timeout, TcpStream :: connect ( socket_addrs. as_slice ( ) ) )
365
+ timeout ( self . connect_timeout , TcpStream :: connect ( & * socket_addrs) )
377
366
. await
378
367
. map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
379
368
. map_err ( |err| match err. kind ( ) {
380
369
std:: io:: ErrorKind :: AddrNotAvailable => {
381
370
dns_error ( "address not available" . into ( ) , 0 )
382
371
}
383
372
_ => ErrorCode :: ConnectionRefused ,
384
- } ) ?,
385
- authority_str,
386
- ) )
373
+ } )
374
+ }
375
+
376
+ async fn connect_tls (
377
+ & self ,
378
+ uri : & Uri ,
379
+ default_port : u16 ,
380
+ ) -> Result < TlsStream < TcpStream > , ErrorCode > {
381
+ let tcp_stream = self . connect_tcp ( uri, default_port) . await ?;
382
+
383
+ let mut tls_client_config = self . tls_client_config . as_deref ( ) . unwrap ( ) . clone ( ) ;
384
+ tls_client_config. alpn_protocols = vec ! [ b"h2" . to_vec( ) , b"http/1.1" . to_vec( ) ] ;
385
+
386
+ let connector = tokio_rustls:: TlsConnector :: from ( Arc :: new ( tls_client_config) ) ;
387
+ let domain = rustls:: pki_types:: ServerName :: try_from ( uri. host ( ) . unwrap ( ) )
388
+ . map_err ( |e| {
389
+ tracing:: warn!( "dns lookup error: {e:?}" ) ;
390
+ dns_error ( "invalid dns name" . to_string ( ) , 0 )
391
+ } ) ?
392
+ . to_owned ( ) ;
393
+ connector. connect ( domain, tcp_stream) . await . map_err ( |e| {
394
+ tracing:: warn!( "tls protocol error: {e:?}" ) ;
395
+ ErrorCode :: TlsProtocolError
396
+ } )
397
+ }
387
398
}
388
399
389
400
#[ derive( Clone ) ]
390
401
struct HttpConnector ;
391
402
392
403
impl HttpConnector {
393
404
async fn connect ( uri : Uri ) -> Result < TokioIo < TcpStream > , ErrorCode > {
394
- Ok ( TokioIo :: new ( connect_tcp ( uri, 80 ) . await ?. 0 ) )
405
+ let stream = CONNECT_OPTIONS . get ( ) . connect_tcp ( & uri, 80 ) . await ?;
406
+ Ok ( TokioIo :: new ( stream) )
395
407
}
396
408
}
397
409
@@ -414,27 +426,7 @@ struct HttpsConnector;
414
426
415
427
impl HttpsConnector {
416
428
async fn connect ( uri : Uri ) -> Result < TokioIo < RustlsStream > , ErrorCode > {
417
- use rustls:: pki_types:: ServerName ;
418
-
419
- let ( tcp_stream, authority_str) = connect_tcp ( uri, 443 ) . await ?;
420
-
421
- let mut tls_client_config = ( * TLS_CLIENT_CONFIG . get ( ) ) . clone ( ) ;
422
- tls_client_config. alpn_protocols = vec ! [ b"h2" . to_vec( ) , b"http/1.1" . to_vec( ) ] ;
423
-
424
- let connector = tokio_rustls:: TlsConnector :: from ( Arc :: new ( tls_client_config) ) ;
425
- let mut parts = authority_str. split ( ':' ) ;
426
- let host = parts. next ( ) . unwrap_or ( & authority_str) ;
427
- let domain = ServerName :: try_from ( host)
428
- . map_err ( |e| {
429
- tracing:: warn!( "dns lookup error: {e:?}" ) ;
430
- dns_error ( "invalid dns name" . to_string ( ) , 0 )
431
- } ) ?
432
- . to_owned ( ) ;
433
- let stream = connector. connect ( domain, tcp_stream) . await . map_err ( |e| {
434
- tracing:: warn!( "tls protocol error: {e:?}" ) ;
435
- ErrorCode :: TlsProtocolError
436
- } ) ?;
437
-
429
+ let stream = CONNECT_OPTIONS . get ( ) . connect_tls ( & uri, 443 ) . await ?;
438
430
Ok ( TokioIo :: new ( RustlsStream ( stream) ) )
439
431
}
440
432
}
0 commit comments