diff --git a/crates/factor-outbound-http/src/intercept.rs b/crates/factor-outbound-http/src/intercept.rs index a818c8e91f..c1252e54f0 100644 --- a/crates/factor-outbound-http/src/intercept.rs +++ b/crates/factor-outbound-http/src/intercept.rs @@ -1,3 +1,5 @@ +use std::net::SocketAddr; + use http::{Request, Response}; use http_body_util::{BodyExt, Full}; use spin_world::async_trait; @@ -39,7 +41,7 @@ pub enum InterceptOutcome { pub struct InterceptRequest { inner: Request<()>, body: InterceptBody, - pub(crate) override_connect_host: Option, + pub(crate) override_connect_addr: Option, } enum InterceptBody { @@ -48,16 +50,17 @@ enum InterceptBody { } impl InterceptRequest { - /// Overrides the host that will be connected to for this outbound request. + /// Overrides the IP and port that will be connected to for this outbound + /// request. /// /// This override does not have any effect on TLS server name checking or /// HTTP authority / host headers. /// - /// This host will not be checked against `allowed_outbound_hosts`; if that - /// check should occur it must be performed by the interceptor. The resolved - /// IP addresses from this host will be checked against blocked IP networks. - pub fn override_connect_host(&mut self, host: impl Into) { - self.override_connect_host = Some(host.into()) + /// The IP will be checked against blocked IP networks but it will not be + /// checked against `allowed_outbound_hosts`; if that check needs to occur + /// it must be performed by the interceptor. + pub fn override_connect_addr(&mut self, endpoint: SocketAddr) { + self.override_connect_addr = Some(endpoint); } pub fn into_hyper_request(self) -> Request { @@ -94,7 +97,7 @@ impl From> for InterceptRequest { Self { inner: Request::from_parts(parts, ()), body: InterceptBody::Hyper(body), - override_connect_host: None, + override_connect_addr: None, } } } @@ -105,7 +108,7 @@ impl From>> for InterceptRequest { Self { inner: Request::from_parts(parts, ()), body: InterceptBody::Vec(body), - override_connect_host: None, + override_connect_addr: None, } } } diff --git a/crates/factor-outbound-http/src/wasi.rs b/crates/factor-outbound-http/src/wasi.rs index a28fd31ec0..756799adf3 100644 --- a/crates/factor-outbound-http/src/wasi.rs +++ b/crates/factor-outbound-http/src/wasi.rs @@ -2,6 +2,7 @@ use std::{ error::Error, future::Future, io::IoSlice, + net::SocketAddr, pin::Pin, sync::Arc, task::{Context, Poll}, @@ -113,7 +114,7 @@ impl WasiHttpView for WasiHttpImplInner<'_> { http.response.status_code = Empty, server.address = Empty, server.port = Empty, - ), + ) )] fn send_request( &mut self, @@ -166,12 +167,12 @@ impl RequestSender { spin_telemetry::inject_trace_context(&mut request); // Run any configured request interceptor - let mut override_connect_host = None; + let mut override_connect_addr = None; if let Some(interceptor) = &self.request_interceptor { let intercept_request = std::mem::take(&mut request).into(); match interceptor.intercept(intercept_request).await? { InterceptOutcome::Continue(mut req) => { - override_connect_host = req.override_connect_host.take(); + override_connect_addr = req.override_connect_addr.take(); request = req.into_hyper_request(); } InterceptOutcome::Complete(resp) => { @@ -186,17 +187,19 @@ impl RequestSender { } // Backfill span fields after potentially updating the URL in the interceptor - if let Some(authority) = request.uri().authority() { - let span = tracing::Span::current(); - let host = override_connect_host.as_deref().unwrap_or(authority.host()); - span.record("server.address", host); - if let Some(port) = authority.port() { - span.record("server.port", port.as_u16()); + let span = tracing::Span::current(); + if let Some(addr) = override_connect_addr { + span.record("server.address", addr.ip().to_string()); + span.record("server.port", addr.port()); + } else if let Some(authority) = request.uri().authority() { + span.record("server.address", authority.host()); + if let Some(port) = authority.port_u16() { + span.record("server.port", port); } } Ok(self - .send_request(request, config, override_connect_host) + .send_request(request, config, override_connect_addr) .await?) } @@ -275,7 +278,7 @@ impl RequestSender { self, request: OutgoingRequest, config: OutgoingRequestConfig, - override_connect_host: Option, + override_connect_addr: Option, ) -> Result { let OutgoingRequestConfig { use_tls, @@ -296,7 +299,7 @@ impl RequestSender { blocked_networks: self.blocked_networks, connect_timeout, tls_client_config, - override_connect_host, + override_connect_addr, }, async move { if use_tls { @@ -376,26 +379,33 @@ struct ConnectOptions { blocked_networks: BlockedNetworks, connect_timeout: Duration, tls_client_config: Option, - override_connect_host: Option, + override_connect_addr: Option, } impl ConnectOptions { async fn connect_tcp(&self, uri: &Uri, default_port: u16) -> Result { - let host = self - .override_connect_host - .as_deref() - .or(uri.host()) - .ok_or(ErrorCode::HttpRequestUriInvalid)?; - let host_and_port = (host, uri.port_u16().unwrap_or(default_port)); - - let mut socket_addrs = tokio::net::lookup_host(host_and_port) - .await - .map_err(|err| { - tracing::debug!(?host_and_port, ?err, "Error resolving host"); - dns_error("address not available".into(), 0) - })? - .collect::>(); - tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host"); + let mut socket_addrs = match self.override_connect_addr { + Some(override_connect_addr) => vec![override_connect_addr], + None => { + let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?; + + let host_and_port = if authority.port().is_some() { + authority.as_str().to_string() + } else { + format!("{}:{}", authority.as_str(), default_port) + }; + + let socket_addrs = tokio::net::lookup_host(&host_and_port) + .await + .map_err(|err| { + tracing::debug!(?host_and_port, ?err, "Error resolving host"); + dns_error("address not available".into(), 0) + })? + .collect::>(); + tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host"); + socket_addrs + } + }; // Remove blocked IPs let blocked_addrs = self.blocked_networks.remove_blocked(&mut socket_addrs); diff --git a/crates/factor-outbound-http/tests/factor_test.rs b/crates/factor-outbound-http/tests/factor_test.rs index 44c4754e95..76601fd280 100644 --- a/crates/factor-outbound-http/tests/factor_test.rs +++ b/crates/factor-outbound-http/tests/factor_test.rs @@ -22,7 +22,7 @@ struct TestFactors { http: OutboundHttpFactor, } -#[tokio::test] +#[tokio::test(flavor = "multi_thread")] async fn allowed_host_is_allowed() -> anyhow::Result<()> { let mut state = test_instance_state("https://*", true).await?; let mut wasi_http = OutboundHttpFactor::get_wasi_http_impl(&mut state).unwrap(); @@ -36,7 +36,7 @@ async fn allowed_host_is_allowed() -> anyhow::Result<()> { Ok(()) } -#[tokio::test] +#[tokio::test(flavor = "multi_thread")] async fn self_request_smoke_test() -> anyhow::Result<()> { let mut state = test_instance_state("http://self", true).await?; // [100::] is the IPv6 "Discard Prefix", which should always fail @@ -52,7 +52,7 @@ async fn self_request_smoke_test() -> anyhow::Result<()> { Ok(()) } -#[tokio::test] +#[tokio::test(flavor = "multi_thread")] async fn disallowed_host_fails() -> anyhow::Result<()> { let mut state = test_instance_state("https://allowed.test", true).await?; let mut wasi_http = OutboundHttpFactor::get_wasi_http_impl(&mut state).unwrap(); @@ -67,7 +67,7 @@ async fn disallowed_host_fails() -> anyhow::Result<()> { Ok(()) } -#[tokio::test] +#[tokio::test(flavor = "multi_thread")] async fn disallowed_private_ips_fails() -> anyhow::Result<()> { async fn run_test(allow_private_ips: bool) -> anyhow::Result<()> { let mut state = test_instance_state("http://*", allow_private_ips).await?; @@ -100,8 +100,8 @@ async fn disallowed_private_ips_fails() -> anyhow::Result<()> { Ok(()) } -#[tokio::test] -async fn override_connect_host_disallowed_private_ip_fails() -> anyhow::Result<()> { +#[tokio::test(flavor = "multi_thread")] +async fn override_connect_addr_disallowed_private_ip_fails() -> anyhow::Result<()> { let mut state = test_instance_state("http://*", false).await?; state.http.set_request_interceptor({ struct Interceptor; @@ -111,7 +111,7 @@ async fn override_connect_host_disallowed_private_ip_fails() -> anyhow::Result<( &self, mut request: InterceptRequest, ) -> wasmtime_wasi_http::HttpResult { - request.override_connect_host("localhost"); + request.override_connect_addr("[::1]:80".parse().unwrap()); Ok(InterceptOutcome::Continue(request)) } } @@ -159,8 +159,8 @@ fn test_request_config() -> OutgoingRequestConfig { OutgoingRequestConfig { use_tls: false, connect_timeout: Duration::from_millis(10), - first_byte_timeout: Duration::from_millis(10), - between_bytes_timeout: Duration::from_millis(10), + first_byte_timeout: Duration::from_millis(0), + between_bytes_timeout: Duration::from_millis(0), } }