Skip to content

Commit dcfd0c5

Browse files
committed
factor-outbound-http: Refactor ConnectOptions
Signed-off-by: Lann Martin <[email protected]>
1 parent 7d73906 commit dcfd0c5

File tree

1 file changed

+73
-81
lines changed
  • crates/factor-outbound-http/src

1 file changed

+73
-81
lines changed

crates/factor-outbound-http/src/wasi.rs

Lines changed: 73 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,9 @@ impl RequestSender {
172172
}
173173
span.record("url.full", uri.to_string());
174174

175+
// If the current span has opentelemetry trace context, inject it into the request
175176
spin_telemetry::inject_trace_context(&mut request);
176177

177-
let host = request.uri().host().unwrap_or_default();
178-
let tls_client_config = self.component_tls_configs.get_client_config(host).clone();
179-
180178
let is_self_request = request
181179
.uri()
182180
.authority()
@@ -243,34 +241,37 @@ impl RequestSender {
243241
}
244242
}
245243

244+
// Backfill span fields after potentially updating the URL in the interceptor
246245
let authority = request.uri().authority().context("authority not set")?;
247246
span.record("server.address", authority.host());
248247
if let Some(port) = authority.port() {
249248
span.record("server.port", port.as_u16());
250249
}
251250

251+
let tls_client_config = if use_tls {
252+
let host = request.uri().host().unwrap_or_default();
253+
Some(self.component_tls_configs.get_client_config(host).clone())
254+
} else {
255+
None
256+
};
257+
252258
let resp = CONNECT_OPTIONS.scope(
253259
ConnectOptions {
254260
blocked_networks: self.blocked_networks,
255261
connect_timeout,
262+
tls_client_config,
256263
},
257264
async move {
258265
if use_tls {
259-
TLS_CLIENT_CONFIG
260-
.scope(tls_client_config, async move {
261-
self.http_clients.https.request(request).await
262-
})
263-
.await
266+
self.http_clients.https.request(request).await
264267
} else {
265-
let use_http2 = std::env::var_os("SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE")
266-
.is_some_and(|v| {
267-
request
268-
.uri()
269-
.authority()
270-
.is_some_and(|authority| authority.as_str() == v)
271-
});
272-
273-
if use_http2 {
268+
// For development purposes, allow configuring plaintext HTTP/2 for a specific host.
269+
let h2c_prior_knowledge_host =
270+
std::env::var("SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE").ok();
271+
let use_h2c = h2c_prior_knowledge_host.as_deref()
272+
== request.uri().authority().map(|a| a.as_str());
273+
274+
if use_h2c {
274275
self.http_clients.http2.request(request).await
275276
} else {
276277
self.http_clients.http1.request(request).await
@@ -325,73 +326,84 @@ impl HttpClients {
325326
}
326327
}
327328

328-
#[derive(Clone)]
329-
struct ConnectOptions {
330-
blocked_networks: BlockedNetworks,
331-
connect_timeout: Duration,
332-
}
333-
334329
// We must use task-local variables for these config options when using
335330
// `hyper_util::client::legacy::Client::request` because there's no way to plumb
336331
// them through as parameters. Moreover, if there's already a pooled connection
337332
// ready, we'll reuse that and ignore these options anyway.
338333
tokio::task_local! {
339334
static CONNECT_OPTIONS: ConnectOptions;
340-
static TLS_CLIENT_CONFIG: TlsClientConfig;
341335
}
342336

343-
async fn connect_tcp(uri: Uri, default_port: u16) -> Result<(TcpStream, String), ErrorCode> {
344-
let authority_str = if let Some(authority) = uri.authority() {
345-
if authority.port().is_some() {
346-
authority.to_string()
347-
} else {
348-
format!("{authority}:{default_port}")
337+
#[derive(Clone)]
338+
struct ConnectOptions {
339+
blocked_networks: BlockedNetworks,
340+
connect_timeout: Duration,
341+
tls_client_config: Option<TlsClientConfig>,
342+
}
343+
344+
impl ConnectOptions {
345+
async fn connect_tcp(&self, uri: &Uri, default_port: u16) -> Result<TcpStream, ErrorCode> {
346+
let host = uri.host().ok_or(ErrorCode::HttpRequestUriInvalid)?;
347+
let host_and_port = (host, uri.port_u16().unwrap_or(default_port));
348+
349+
let mut socket_addrs = tokio::net::lookup_host(host_and_port)
350+
.await
351+
.map_err(|_| dns_error("address not available".into(), 0))?
352+
.collect::<Vec<_>>();
353+
354+
// Remove blocked IPs
355+
let blocked_addrs = self.blocked_networks.remove_blocked(&mut socket_addrs);
356+
if socket_addrs.is_empty() && !blocked_addrs.is_empty() {
357+
tracing::error!(
358+
"error.type" = "destination_ip_prohibited",
359+
?blocked_addrs,
360+
"all destination IP(s) prohibited by runtime config"
361+
);
362+
return Err(ErrorCode::DestinationIpProhibited);
349363
}
350-
} else {
351-
return Err(ErrorCode::HttpRequestUriInvalid);
352-
};
353-
354-
let ConnectOptions {
355-
blocked_networks,
356-
connect_timeout,
357-
} = CONNECT_OPTIONS.get();
358-
359-
let mut socket_addrs = tokio::net::lookup_host(&authority_str)
360-
.await
361-
.map_err(|_| dns_error("address not available".into(), 0))?
362-
.collect::<Vec<_>>();
363-
364-
// Remove blocked IPs
365-
let blocked_addrs = blocked_networks.remove_blocked(&mut socket_addrs);
366-
if socket_addrs.is_empty() && !blocked_addrs.is_empty() {
367-
tracing::error!(
368-
"error.type" = "destination_ip_prohibited",
369-
?blocked_addrs,
370-
"all destination IP(s) prohibited by runtime config"
371-
);
372-
return Err(ErrorCode::DestinationIpProhibited);
373-
}
374364

375-
Ok((
376-
timeout(connect_timeout, TcpStream::connect(socket_addrs.as_slice()))
365+
timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
377366
.await
378367
.map_err(|_| ErrorCode::ConnectionTimeout)?
379368
.map_err(|err| match err.kind() {
380369
std::io::ErrorKind::AddrNotAvailable => {
381370
dns_error("address not available".into(), 0)
382371
}
383372
_ => ErrorCode::ConnectionRefused,
384-
})?,
385-
authority_str,
386-
))
373+
})
374+
}
375+
376+
async fn connect_tls(
377+
&self,
378+
uri: &Uri,
379+
default_port: u16,
380+
) -> Result<TlsStream<TcpStream>, ErrorCode> {
381+
let tcp_stream = self.connect_tcp(uri, default_port).await?;
382+
383+
let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
384+
tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
385+
386+
let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));
387+
let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
388+
.map_err(|e| {
389+
tracing::warn!("dns lookup error: {e:?}");
390+
dns_error("invalid dns name".to_string(), 0)
391+
})?
392+
.to_owned();
393+
connector.connect(domain, tcp_stream).await.map_err(|e| {
394+
tracing::warn!("tls protocol error: {e:?}");
395+
ErrorCode::TlsProtocolError
396+
})
397+
}
387398
}
388399

389400
#[derive(Clone)]
390401
struct HttpConnector;
391402

392403
impl HttpConnector {
393404
async fn connect(uri: Uri) -> Result<TokioIo<TcpStream>, ErrorCode> {
394-
Ok(TokioIo::new(connect_tcp(uri, 80).await?.0))
405+
let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
406+
Ok(TokioIo::new(stream))
395407
}
396408
}
397409

@@ -414,27 +426,7 @@ struct HttpsConnector;
414426

415427
impl HttpsConnector {
416428
async fn connect(uri: Uri) -> Result<TokioIo<RustlsStream>, ErrorCode> {
417-
use rustls::pki_types::ServerName;
418-
419-
let (tcp_stream, authority_str) = connect_tcp(uri, 443).await?;
420-
421-
let mut tls_client_config = (*TLS_CLIENT_CONFIG.get()).clone();
422-
tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
423-
424-
let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));
425-
let mut parts = authority_str.split(':');
426-
let host = parts.next().unwrap_or(&authority_str);
427-
let domain = ServerName::try_from(host)
428-
.map_err(|e| {
429-
tracing::warn!("dns lookup error: {e:?}");
430-
dns_error("invalid dns name".to_string(), 0)
431-
})?
432-
.to_owned();
433-
let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
434-
tracing::warn!("tls protocol error: {e:?}");
435-
ErrorCode::TlsProtocolError
436-
})?;
437-
429+
let stream = CONNECT_OPTIONS.get().connect_tls(&uri, 443).await?;
438430
Ok(TokioIo::new(RustlsStream(stream)))
439431
}
440432
}

0 commit comments

Comments
 (0)