Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions crates/factor-outbound-http/src/intercept.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::net::SocketAddr;

use http::{Request, Response};
use http_body_util::{BodyExt, Full};
use spin_world::async_trait;
Expand Down Expand Up @@ -39,7 +41,7 @@ pub enum InterceptOutcome {
pub struct InterceptRequest {
inner: Request<()>,
body: InterceptBody,
pub(crate) override_connect_host: Option<String>,
pub(crate) override_connect_addr: Option<SocketAddr>,
}

enum InterceptBody {
Expand All @@ -48,16 +50,17 @@ enum InterceptBody {
}

impl InterceptRequest {
/// Overrides the host that will be connected to for this outbound request.
/// Overrides the IP and port that will be connected to for this outbound
/// request.
///
/// This override does not have any effect on TLS server name checking or
/// HTTP authority / host headers.
///
/// This host will not be checked against `allowed_outbound_hosts`; if that
/// check should occur it must be performed by the interceptor. The resolved
/// IP addresses from this host will be checked against blocked IP networks.
pub fn override_connect_host(&mut self, host: impl Into<String>) {
self.override_connect_host = Some(host.into())
/// The IP will be checked against blocked IP networks but it will not be
/// checked against `allowed_outbound_hosts`; if that check needs to occur
/// it must be performed by the interceptor.
pub fn override_connect_addr(&mut self, endpoint: SocketAddr) {
self.override_connect_addr = Some(endpoint);
}

pub fn into_hyper_request(self) -> Request<HyperBody> {
Expand Down Expand Up @@ -94,7 +97,7 @@ impl From<Request<HyperBody>> for InterceptRequest {
Self {
inner: Request::from_parts(parts, ()),
body: InterceptBody::Hyper(body),
override_connect_host: None,
override_connect_addr: None,
}
}
}
Expand All @@ -105,7 +108,7 @@ impl From<Request<Vec<u8>>> for InterceptRequest {
Self {
inner: Request::from_parts(parts, ()),
body: InterceptBody::Vec(body),
override_connect_host: None,
override_connect_addr: None,
}
}
}
Expand Down
66 changes: 38 additions & 28 deletions crates/factor-outbound-http/src/wasi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{
error::Error,
future::Future,
io::IoSlice,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
Expand Down Expand Up @@ -113,7 +114,7 @@ impl WasiHttpView for WasiHttpImplInner<'_> {
http.response.status_code = Empty,
server.address = Empty,
server.port = Empty,
),
)
)]
fn send_request(
&mut self,
Expand Down Expand Up @@ -166,12 +167,12 @@ impl RequestSender {
spin_telemetry::inject_trace_context(&mut request);

// Run any configured request interceptor
let mut override_connect_host = None;
let mut override_connect_addr = None;
if let Some(interceptor) = &self.request_interceptor {
let intercept_request = std::mem::take(&mut request).into();
match interceptor.intercept(intercept_request).await? {
InterceptOutcome::Continue(mut req) => {
override_connect_host = req.override_connect_host.take();
override_connect_addr = req.override_connect_addr.take();
request = req.into_hyper_request();
}
InterceptOutcome::Complete(resp) => {
Expand All @@ -186,17 +187,19 @@ impl RequestSender {
}

// Backfill span fields after potentially updating the URL in the interceptor
if let Some(authority) = request.uri().authority() {
let span = tracing::Span::current();
let host = override_connect_host.as_deref().unwrap_or(authority.host());
span.record("server.address", host);
if let Some(port) = authority.port() {
span.record("server.port", port.as_u16());
let span = tracing::Span::current();
if let Some(addr) = override_connect_addr {
span.record("server.address", addr.ip().to_string());
span.record("server.port", addr.port());
} else if let Some(authority) = request.uri().authority() {
span.record("server.address", authority.host());
if let Some(port) = authority.port_u16() {
span.record("server.port", port);
}
}

Ok(self
.send_request(request, config, override_connect_host)
.send_request(request, config, override_connect_addr)
.await?)
}

Expand Down Expand Up @@ -275,7 +278,7 @@ impl RequestSender {
self,
request: OutgoingRequest,
config: OutgoingRequestConfig,
override_connect_host: Option<String>,
override_connect_addr: Option<SocketAddr>,
) -> Result<IncomingResponse, ErrorCode> {
let OutgoingRequestConfig {
use_tls,
Expand All @@ -296,7 +299,7 @@ impl RequestSender {
blocked_networks: self.blocked_networks,
connect_timeout,
tls_client_config,
override_connect_host,
override_connect_addr,
},
async move {
if use_tls {
Expand Down Expand Up @@ -376,26 +379,33 @@ struct ConnectOptions {
blocked_networks: BlockedNetworks,
connect_timeout: Duration,
tls_client_config: Option<TlsClientConfig>,
override_connect_host: Option<String>,
override_connect_addr: Option<SocketAddr>,
}

impl ConnectOptions {
async fn connect_tcp(&self, uri: &Uri, default_port: u16) -> Result<TcpStream, ErrorCode> {
let host = self
.override_connect_host
.as_deref()
.or(uri.host())
.ok_or(ErrorCode::HttpRequestUriInvalid)?;
let host_and_port = (host, uri.port_u16().unwrap_or(default_port));

let mut socket_addrs = tokio::net::lookup_host(host_and_port)
.await
.map_err(|err| {
tracing::debug!(?host_and_port, ?err, "Error resolving host");
dns_error("address not available".into(), 0)
})?
.collect::<Vec<_>>();
tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
let mut socket_addrs = match self.override_connect_addr {
Some(override_connect_addr) => vec![override_connect_addr],
None => {
let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;

let host_and_port = if authority.port().is_some() {
authority.as_str().to_string()
} else {
format!("{}:{}", authority.as_str(), default_port)
};

let socket_addrs = tokio::net::lookup_host(&host_and_port)
.await
.map_err(|err| {
tracing::debug!(?host_and_port, ?err, "Error resolving host");
dns_error("address not available".into(), 0)
})?
.collect::<Vec<_>>();
tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
socket_addrs
}
};

// Remove blocked IPs
let blocked_addrs = self.blocked_networks.remove_blocked(&mut socket_addrs);
Expand Down
18 changes: 9 additions & 9 deletions crates/factor-outbound-http/tests/factor_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct TestFactors {
http: OutboundHttpFactor,
}

#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn allowed_host_is_allowed() -> anyhow::Result<()> {
let mut state = test_instance_state("https://*", true).await?;
let mut wasi_http = OutboundHttpFactor::get_wasi_http_impl(&mut state).unwrap();
Expand All @@ -36,7 +36,7 @@ async fn allowed_host_is_allowed() -> anyhow::Result<()> {
Ok(())
}

#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn self_request_smoke_test() -> anyhow::Result<()> {
let mut state = test_instance_state("http://self", true).await?;
// [100::] is the IPv6 "Discard Prefix", which should always fail
Expand All @@ -52,7 +52,7 @@ async fn self_request_smoke_test() -> anyhow::Result<()> {
Ok(())
}

#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn disallowed_host_fails() -> anyhow::Result<()> {
let mut state = test_instance_state("https://allowed.test", true).await?;
let mut wasi_http = OutboundHttpFactor::get_wasi_http_impl(&mut state).unwrap();
Expand All @@ -67,7 +67,7 @@ async fn disallowed_host_fails() -> anyhow::Result<()> {
Ok(())
}

#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn disallowed_private_ips_fails() -> anyhow::Result<()> {
async fn run_test(allow_private_ips: bool) -> anyhow::Result<()> {
let mut state = test_instance_state("http://*", allow_private_ips).await?;
Expand Down Expand Up @@ -100,8 +100,8 @@ async fn disallowed_private_ips_fails() -> anyhow::Result<()> {
Ok(())
}

#[tokio::test]
async fn override_connect_host_disallowed_private_ip_fails() -> anyhow::Result<()> {
#[tokio::test(flavor = "multi_thread")]
async fn override_connect_addr_disallowed_private_ip_fails() -> anyhow::Result<()> {
let mut state = test_instance_state("http://*", false).await?;
state.http.set_request_interceptor({
struct Interceptor;
Expand All @@ -111,7 +111,7 @@ async fn override_connect_host_disallowed_private_ip_fails() -> anyhow::Result<(
&self,
mut request: InterceptRequest,
) -> wasmtime_wasi_http::HttpResult<InterceptOutcome> {
request.override_connect_host("localhost");
request.override_connect_addr("[::1]:80".parse().unwrap());
Ok(InterceptOutcome::Continue(request))
}
}
Expand Down Expand Up @@ -159,8 +159,8 @@ fn test_request_config() -> OutgoingRequestConfig {
OutgoingRequestConfig {
use_tls: false,
connect_timeout: Duration::from_millis(10),
first_byte_timeout: Duration::from_millis(10),
between_bytes_timeout: Duration::from_millis(10),
first_byte_timeout: Duration::from_millis(0),
between_bytes_timeout: Duration::from_millis(0),
}
}

Expand Down
Loading