1- use std:: { error:: Error , sync:: Arc } ;
1+ use std:: { error:: Error , future :: Future , pin :: Pin , sync:: Arc } ;
22
33use anyhow:: Context ;
4+ use bytes:: Bytes ;
45use http:: { header:: HOST , Request } ;
5- use http_body_util:: BodyExt ;
6+ use http_body_util:: { combinators:: BoxBody , BodyExt } ;
7+ use hyper_util:: rt:: TokioExecutor ;
68use spin_factor_outbound_networking:: {
79 config:: { allowed_hosts:: OutboundAllowedHosts , blocked_networks:: BlockedNetworks } ,
810 ComponentTlsClientConfigs , TlsClientConfig ,
@@ -259,7 +261,7 @@ async fn send_request_handler(
259261 _ => ErrorCode :: ConnectionRefused ,
260262 } ) ?;
261263
262- let ( mut sender, worker) = if use_tls {
264+ let ( mut sender, worker, is_http2 ) = if use_tls {
263265 #[ cfg( any( target_arch = "riscv64" , target_arch = "s390x" ) ) ]
264266 {
265267 return Err ( ErrorCode :: InternalError ( Some (
@@ -270,7 +272,11 @@ async fn send_request_handler(
270272 #[ cfg( not( any( target_arch = "riscv64" , target_arch = "s390x" ) ) ) ]
271273 {
272274 use rustls:: pki_types:: ServerName ;
273- let connector = tokio_rustls:: TlsConnector :: from ( tls_client_config. inner ( ) ) ;
275+
276+ let mut tls_client_config = ( * tls_client_config) . clone ( ) ;
277+ tls_client_config. alpn_protocols = vec ! [ b"h2" . to_vec( ) , b"http/1.1" . to_vec( ) ] ;
278+
279+ let connector = tokio_rustls:: TlsConnector :: from ( Arc :: new ( tls_client_config) ) ;
274280 let mut parts = authority_str. split ( ':' ) ;
275281 let host = parts. next ( ) . unwrap_or ( & authority_str) ;
276282 let domain = ServerName :: try_from ( host)
@@ -283,15 +289,30 @@ async fn send_request_handler(
283289 tracing:: warn!( "tls protocol error: {e:?}" ) ;
284290 ErrorCode :: TlsProtocolError
285291 } ) ?;
292+
293+ let is_http2 = stream. get_ref ( ) . 1 . alpn_protocol ( ) == Some ( b"h2" ) ;
294+
286295 let stream = TokioIo :: new ( stream) ;
287296
288- let ( sender, conn) = timeout (
289- connect_timeout,
290- hyper:: client:: conn:: http1:: handshake ( stream) ,
291- )
292- . await
293- . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
294- . map_err ( hyper_request_error) ?;
297+ let ( sender, conn) = if is_http2 {
298+ timeout (
299+ connect_timeout,
300+ hyper:: client:: conn:: http2:: handshake ( TokioExecutor :: default ( ) , stream) ,
301+ )
302+ . await
303+ . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
304+ . map_err ( hyper_request_error)
305+ . map ( |( sender, conn) | ( HttpSender :: Http2 ( sender) , HttpConn :: Http2 ( conn) ) ) ?
306+ } else {
307+ timeout (
308+ connect_timeout,
309+ hyper:: client:: conn:: http1:: handshake ( stream) ,
310+ )
311+ . await
312+ . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
313+ . map_err ( hyper_request_error)
314+ . map ( |( sender, conn) | ( HttpSender :: Http1 ( sender) , HttpConn :: Http1 ( conn) ) ) ?
315+ } ;
295316
296317 let worker = wasmtime_wasi:: runtime:: spawn ( async move {
297318 match conn. await {
@@ -302,18 +323,37 @@ async fn send_request_handler(
302323 }
303324 } ) ;
304325
305- ( sender, worker)
326+ ( sender, worker, is_http2 )
306327 }
307328 } else {
308329 let tcp_stream = TokioIo :: new ( tcp_stream) ;
309- let ( sender, conn) = timeout (
310- connect_timeout,
311- // TODO: we should plumb the builder through the http context, and use it here
312- hyper:: client:: conn:: http1:: handshake ( tcp_stream) ,
313- )
314- . await
315- . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
316- . map_err ( hyper_request_error) ?;
330+
331+ let is_http2 = std:: env:: var_os ( "SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE" ) . is_some_and ( |v| {
332+ request
333+ . uri ( )
334+ . authority ( )
335+ . is_some_and ( |authority| authority. as_str ( ) == v)
336+ } ) ;
337+
338+ let ( sender, conn) = if is_http2 {
339+ timeout (
340+ connect_timeout,
341+ hyper:: client:: conn:: http2:: handshake ( TokioExecutor :: default ( ) , tcp_stream) ,
342+ )
343+ . await
344+ . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
345+ . map_err ( hyper_request_error)
346+ . map ( |( sender, conn) | ( HttpSender :: Http2 ( sender) , HttpConn :: Http2 ( conn) ) ) ?
347+ } else {
348+ timeout (
349+ connect_timeout,
350+ hyper:: client:: conn:: http1:: handshake ( tcp_stream) ,
351+ )
352+ . await
353+ . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
354+ . map_err ( hyper_request_error)
355+ . map ( |( sender, conn) | ( HttpSender :: Http1 ( sender) , HttpConn :: Http1 ( conn) ) ) ?
356+ } ;
317357
318358 let worker = wasmtime_wasi:: runtime:: spawn ( async move {
319359 match conn. await {
@@ -323,22 +363,24 @@ async fn send_request_handler(
323363 }
324364 } ) ;
325365
326- ( sender, worker)
366+ ( sender, worker, is_http2 )
327367 } ;
328368
329- // at this point, the request contains the scheme and the authority, but
330- // the http packet should only include those if addressing a proxy, so
331- // remove them here, since SendRequest::send_request does not do it for us
332- * request. uri_mut ( ) = http:: Uri :: builder ( )
333- . path_and_query (
334- request
335- . uri ( )
336- . path_and_query ( )
337- . map ( |p| p. as_str ( ) )
338- . unwrap_or ( "/" ) ,
339- )
340- . build ( )
341- . expect ( "comes from valid request" ) ;
369+ if !is_http2 {
370+ // at this point, the request contains the scheme and the authority, but
371+ // the http packet should only include those if addressing a proxy, so
372+ // remove them here, since SendRequest::send_request does not do it for us
373+ * request. uri_mut ( ) = http:: Uri :: builder ( )
374+ . path_and_query (
375+ request
376+ . uri ( )
377+ . path_and_query ( )
378+ . map ( |p| p. as_str ( ) )
379+ . unwrap_or ( "/" ) ,
380+ )
381+ . build ( )
382+ . expect ( "comes from valid request" ) ;
383+ }
342384
343385 let resp = timeout ( first_byte_timeout, sender. send_request ( request) )
344386 . await
@@ -355,6 +397,43 @@ async fn send_request_handler(
355397 } )
356398}
357399
400+ enum HttpSender {
401+ Http1 ( hyper:: client:: conn:: http1:: SendRequest < BoxBody < Bytes , ErrorCode > > ) ,
402+ Http2 ( hyper:: client:: conn:: http2:: SendRequest < BoxBody < Bytes , ErrorCode > > ) ,
403+ }
404+
405+ #[ allow( clippy:: large_enum_variant) ]
406+ enum HttpConn < T : hyper:: rt:: Read + hyper:: rt:: Write + Unpin + Send + ' static > {
407+ Http1 ( hyper:: client:: conn:: http1:: Connection < T , BoxBody < Bytes , ErrorCode > > ) ,
408+ Http2 ( hyper:: client:: conn:: http2:: Connection < T , BoxBody < Bytes , ErrorCode > , TokioExecutor > ) ,
409+ }
410+
411+ impl < T : hyper:: rt:: Read + hyper:: rt:: Write + Unpin + Send > Future for HttpConn < T > {
412+ type Output = Result < ( ) , hyper:: Error > ;
413+
414+ fn poll (
415+ self : Pin < & mut Self > ,
416+ cx : & mut std:: task:: Context < ' _ > ,
417+ ) -> std:: task:: Poll < Self :: Output > {
418+ match self . get_mut ( ) {
419+ HttpConn :: Http1 ( conn) => Pin :: new ( conn) . poll ( cx) ,
420+ HttpConn :: Http2 ( conn) => Pin :: new ( conn) . poll ( cx) ,
421+ }
422+ }
423+ }
424+
425+ impl HttpSender {
426+ async fn send_request (
427+ & mut self ,
428+ request : http:: Request < BoxBody < Bytes , ErrorCode > > ,
429+ ) -> Result < http:: Response < hyper:: body:: Incoming > , hyper:: Error > {
430+ match self {
431+ HttpSender :: Http1 ( sender) => sender. send_request ( request) . await ,
432+ HttpSender :: Http2 ( sender) => sender. send_request ( request) . await ,
433+ }
434+ }
435+ }
436+
358437/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
359438fn hyper_request_error ( err : hyper:: Error ) -> ErrorCode {
360439 // If there's a source, we might be able to extract a wasi-http error from it.
0 commit comments