Skip to content

Commit 6ad9d0b

Browse files
committed
factor-outbound-http: Update override_connect_host
- Rename to override_connect_addr - Change from String to SocketAddr - Fix bug in host resolution for ipv6 addresses Signed-off-by: Lann Martin <[email protected]>
1 parent d3d0903 commit 6ad9d0b

File tree

3 files changed

+59
-46
lines changed

3 files changed

+59
-46
lines changed

crates/factor-outbound-http/src/intercept.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::net::SocketAddr;
2+
13
use http::{Request, Response};
24
use http_body_util::{BodyExt, Full};
35
use spin_world::async_trait;
@@ -39,7 +41,7 @@ pub enum InterceptOutcome {
3941
pub struct InterceptRequest {
4042
inner: Request<()>,
4143
body: InterceptBody,
42-
pub(crate) override_connect_host: Option<String>,
44+
pub(crate) override_connect_addr: Option<SocketAddr>,
4345
}
4446

4547
enum InterceptBody {
@@ -48,16 +50,17 @@ enum InterceptBody {
4850
}
4951

5052
impl InterceptRequest {
51-
/// Overrides the host that will be connected to for this outbound request.
53+
/// Overrides the IP and port that will be connected to for this outbound
54+
/// request.
5255
///
5356
/// This override does not have any effect on TLS server name checking or
5457
/// HTTP authority / host headers.
5558
///
56-
/// This host will not be checked against `allowed_outbound_hosts`; if that
57-
/// check should occur it must be performed by the interceptor. The resolved
58-
/// IP addresses from this host will be checked against blocked IP networks.
59-
pub fn override_connect_host(&mut self, host: impl Into<String>) {
60-
self.override_connect_host = Some(host.into())
59+
/// The IP will be checked against blocked IP networks but it will not be
60+
/// checked against `allowed_outbound_hosts`; if that check needs to occur
61+
/// it must be performed by the interceptor.
62+
pub fn override_connect_addr(&mut self, endpoint: SocketAddr) {
63+
self.override_connect_addr = Some(endpoint);
6164
}
6265

6366
pub fn into_hyper_request(self) -> Request<HyperBody> {
@@ -94,7 +97,7 @@ impl From<Request<HyperBody>> for InterceptRequest {
9497
Self {
9598
inner: Request::from_parts(parts, ()),
9699
body: InterceptBody::Hyper(body),
97-
override_connect_host: None,
100+
override_connect_addr: None,
98101
}
99102
}
100103
}
@@ -105,7 +108,7 @@ impl From<Request<Vec<u8>>> for InterceptRequest {
105108
Self {
106109
inner: Request::from_parts(parts, ()),
107110
body: InterceptBody::Vec(body),
108-
override_connect_host: None,
111+
override_connect_addr: None,
109112
}
110113
}
111114
}

crates/factor-outbound-http/src/wasi.rs

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::{
22
error::Error,
33
future::Future,
44
io::IoSlice,
5+
net::SocketAddr,
56
pin::Pin,
67
sync::Arc,
78
task::{Context, Poll},
@@ -113,7 +114,7 @@ impl WasiHttpView for WasiHttpImplInner<'_> {
113114
http.response.status_code = Empty,
114115
server.address = Empty,
115116
server.port = Empty,
116-
),
117+
)
117118
)]
118119
fn send_request(
119120
&mut self,
@@ -166,12 +167,12 @@ impl RequestSender {
166167
spin_telemetry::inject_trace_context(&mut request);
167168

168169
// Run any configured request interceptor
169-
let mut override_connect_host = None;
170+
let mut override_connect_addr = None;
170171
if let Some(interceptor) = &self.request_interceptor {
171172
let intercept_request = std::mem::take(&mut request).into();
172173
match interceptor.intercept(intercept_request).await? {
173174
InterceptOutcome::Continue(mut req) => {
174-
override_connect_host = req.override_connect_host.take();
175+
override_connect_addr = req.override_connect_addr.take();
175176
request = req.into_hyper_request();
176177
}
177178
InterceptOutcome::Complete(resp) => {
@@ -186,17 +187,19 @@ impl RequestSender {
186187
}
187188

188189
// Backfill span fields after potentially updating the URL in the interceptor
189-
if let Some(authority) = request.uri().authority() {
190-
let span = tracing::Span::current();
191-
let host = override_connect_host.as_deref().unwrap_or(authority.host());
192-
span.record("server.address", host);
193-
if let Some(port) = authority.port() {
194-
span.record("server.port", port.as_u16());
190+
let span = tracing::Span::current();
191+
if let Some(addr) = override_connect_addr {
192+
span.record("server.address", addr.ip().to_string());
193+
span.record("server.port", addr.port());
194+
} else if let Some(authority) = request.uri().authority() {
195+
span.record("server.address", authority.host());
196+
if let Some(port) = authority.port_u16() {
197+
span.record("server.port", port);
195198
}
196199
}
197200

198201
Ok(self
199-
.send_request(request, config, override_connect_host)
202+
.send_request(request, config, override_connect_addr)
200203
.await?)
201204
}
202205

@@ -275,7 +278,7 @@ impl RequestSender {
275278
self,
276279
request: OutgoingRequest,
277280
config: OutgoingRequestConfig,
278-
override_connect_host: Option<String>,
281+
override_connect_addr: Option<SocketAddr>,
279282
) -> Result<IncomingResponse, ErrorCode> {
280283
let OutgoingRequestConfig {
281284
use_tls,
@@ -296,7 +299,7 @@ impl RequestSender {
296299
blocked_networks: self.blocked_networks,
297300
connect_timeout,
298301
tls_client_config,
299-
override_connect_host,
302+
override_connect_addr,
300303
},
301304
async move {
302305
if use_tls {
@@ -376,26 +379,33 @@ struct ConnectOptions {
376379
blocked_networks: BlockedNetworks,
377380
connect_timeout: Duration,
378381
tls_client_config: Option<TlsClientConfig>,
379-
override_connect_host: Option<String>,
382+
override_connect_addr: Option<SocketAddr>,
380383
}
381384

382385
impl ConnectOptions {
383386
async fn connect_tcp(&self, uri: &Uri, default_port: u16) -> Result<TcpStream, ErrorCode> {
384-
let host = self
385-
.override_connect_host
386-
.as_deref()
387-
.or(uri.host())
388-
.ok_or(ErrorCode::HttpRequestUriInvalid)?;
389-
let host_and_port = (host, uri.port_u16().unwrap_or(default_port));
390-
391-
let mut socket_addrs = tokio::net::lookup_host(host_and_port)
392-
.await
393-
.map_err(|err| {
394-
tracing::debug!(?host_and_port, ?err, "Error resolving host");
395-
dns_error("address not available".into(), 0)
396-
})?
397-
.collect::<Vec<_>>();
398-
tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
387+
let mut socket_addrs = match self.override_connect_addr {
388+
Some(override_connect_addr) => vec![override_connect_addr],
389+
None => {
390+
let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
391+
392+
let host_and_port = if authority.port().is_some() {
393+
authority.as_str().to_string()
394+
} else {
395+
format!("{}:{}", authority.as_str(), default_port)
396+
};
397+
398+
let socket_addrs = tokio::net::lookup_host(&host_and_port)
399+
.await
400+
.map_err(|err| {
401+
tracing::debug!(?host_and_port, ?err, "Error resolving host");
402+
dns_error("address not available".into(), 0)
403+
})?
404+
.collect::<Vec<_>>();
405+
tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
406+
socket_addrs
407+
}
408+
};
399409

400410
// Remove blocked IPs
401411
let blocked_addrs = self.blocked_networks.remove_blocked(&mut socket_addrs);

crates/factor-outbound-http/tests/factor_test.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ struct TestFactors {
2222
http: OutboundHttpFactor,
2323
}
2424

25-
#[tokio::test]
25+
#[tokio::test(flavor = "multi_thread")]
2626
async fn allowed_host_is_allowed() -> anyhow::Result<()> {
2727
let mut state = test_instance_state("https://*", true).await?;
2828
let mut wasi_http = OutboundHttpFactor::get_wasi_http_impl(&mut state).unwrap();
@@ -36,7 +36,7 @@ async fn allowed_host_is_allowed() -> anyhow::Result<()> {
3636
Ok(())
3737
}
3838

39-
#[tokio::test]
39+
#[tokio::test(flavor = "multi_thread")]
4040
async fn self_request_smoke_test() -> anyhow::Result<()> {
4141
let mut state = test_instance_state("http://self", true).await?;
4242
// [100::] is the IPv6 "Discard Prefix", which should always fail
@@ -52,7 +52,7 @@ async fn self_request_smoke_test() -> anyhow::Result<()> {
5252
Ok(())
5353
}
5454

55-
#[tokio::test]
55+
#[tokio::test(flavor = "multi_thread")]
5656
async fn disallowed_host_fails() -> anyhow::Result<()> {
5757
let mut state = test_instance_state("https://allowed.test", true).await?;
5858
let mut wasi_http = OutboundHttpFactor::get_wasi_http_impl(&mut state).unwrap();
@@ -67,7 +67,7 @@ async fn disallowed_host_fails() -> anyhow::Result<()> {
6767
Ok(())
6868
}
6969

70-
#[tokio::test]
70+
#[tokio::test(flavor = "multi_thread")]
7171
async fn disallowed_private_ips_fails() -> anyhow::Result<()> {
7272
async fn run_test(allow_private_ips: bool) -> anyhow::Result<()> {
7373
let mut state = test_instance_state("http://*", allow_private_ips).await?;
@@ -100,8 +100,8 @@ async fn disallowed_private_ips_fails() -> anyhow::Result<()> {
100100
Ok(())
101101
}
102102

103-
#[tokio::test]
104-
async fn override_connect_host_disallowed_private_ip_fails() -> anyhow::Result<()> {
103+
#[tokio::test(flavor = "multi_thread")]
104+
async fn override_connect_addr_disallowed_private_ip_fails() -> anyhow::Result<()> {
105105
let mut state = test_instance_state("http://*", false).await?;
106106
state.http.set_request_interceptor({
107107
struct Interceptor;
@@ -111,7 +111,7 @@ async fn override_connect_host_disallowed_private_ip_fails() -> anyhow::Result<(
111111
&self,
112112
mut request: InterceptRequest,
113113
) -> wasmtime_wasi_http::HttpResult<InterceptOutcome> {
114-
request.override_connect_host("localhost");
114+
request.override_connect_addr("[::1]:80".parse().unwrap());
115115
Ok(InterceptOutcome::Continue(request))
116116
}
117117
}
@@ -159,8 +159,8 @@ fn test_request_config() -> OutgoingRequestConfig {
159159
OutgoingRequestConfig {
160160
use_tls: false,
161161
connect_timeout: Duration::from_millis(10),
162-
first_byte_timeout: Duration::from_millis(10),
163-
between_bytes_timeout: Duration::from_millis(10),
162+
first_byte_timeout: Duration::from_millis(0),
163+
between_bytes_timeout: Duration::from_millis(0),
164164
}
165165
}
166166

0 commit comments

Comments
 (0)