Skip to content

Commit 1b74452

Browse files
committed
parse hosts as Authority and use that as key for client tls opts
Signed-off-by: Rajat Jindal <[email protected]>
1 parent 19e4dff commit 1b74452

File tree

6 files changed

+43
-21
lines changed

6 files changed

+43
-21
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/trigger-http/src/lib.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use std::{
1818
use anyhow::{Context, Result};
1919
use async_trait::async_trait;
2020
use clap::Args;
21-
use http::{header::HOST, uri::Scheme, HeaderValue, StatusCode, Uri};
21+
use http::{header::HOST, uri::Authority, uri::Scheme, HeaderValue, StatusCode, Uri};
2222
use http_body_util::BodyExt;
2323
use hyper::{
2424
body::{Bytes, Incoming},
@@ -587,7 +587,7 @@ pub struct HttpRuntimeData {
587587
origin: Option<String>,
588588
chained_handler: Option<ChainedRequestHandler>,
589589
/// If provided, these options used for client cert auth
590-
client_tls_opts: Option<HashMap<String, ParsedClientTlsOpts>>,
590+
client_tls_opts: Option<HashMap<Authority, ParsedClientTlsOpts>>,
591591
/// The hosts this app is allowed to make outbound requests to
592592
allowed_hosts: AllowedHostsConfig,
593593
}
@@ -1018,9 +1018,9 @@ pub async fn send_request_handler(
10181018
first_byte_timeout,
10191019
between_bytes_timeout,
10201020
}: wasmtime_wasi_http::types::OutgoingRequestConfig,
1021-
client_tls_opts: Option<HashMap<String, ParsedClientTlsOpts>>,
1021+
client_tls_opts: Option<HashMap<Authority, ParsedClientTlsOpts>>,
10221022
) -> Result<wasmtime_wasi_http::types::IncomingResponse, types::ErrorCode> {
1023-
let authority = if let Some(authority) = request.uri().authority() {
1023+
let authority_str = if let Some(authority) = request.uri().authority() {
10241024
if authority.port().is_some() {
10251025
authority.to_string()
10261026
} else {
@@ -1030,7 +1030,10 @@ pub async fn send_request_handler(
10301030
} else {
10311031
return Err(types::ErrorCode::HttpRequestUriInvalid);
10321032
};
1033-
let tcp_stream = timeout(connect_timeout, TcpStream::connect(&authority))
1033+
1034+
let authority = &authority_str.parse::<Authority>().unwrap();
1035+
1036+
let tcp_stream = timeout(connect_timeout, TcpStream::connect(&authority_str))
10341037
.await
10351038
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
10361039
.map_err(|e| match e.kind() {
@@ -1062,8 +1065,8 @@ pub async fn send_request_handler(
10621065
use rustls::pki_types::ServerName;
10631066
let config = get_client_tls_config_for_authority(&authority, client_tls_opts);
10641067
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
1065-
let mut parts = authority.split(":");
1066-
let host = parts.next().unwrap_or(&authority);
1068+
let mut parts = authority_str.split(":");
1069+
let host = parts.next().unwrap_or(&authority_str);
10671070
let domain = ServerName::try_from(host)
10681071
.map_err(|e| {
10691072
tracing::warn!("dns lookup error: {e:?}");
@@ -1145,8 +1148,8 @@ pub async fn send_request_handler(
11451148
}
11461149

11471150
fn get_client_tls_config_for_authority(
1148-
authority: &String,
1149-
client_tls_opts: Option<HashMap<String, ParsedClientTlsOpts>>,
1151+
authority: &Authority,
1152+
client_tls_opts: Option<HashMap<Authority, ParsedClientTlsOpts>>,
11501153
) -> rustls::ClientConfig {
11511154
// derived from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
11521155
let mut root_cert_store = rustls::RootCertStore {

crates/trigger/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ dirs = "4"
2222
futures = "0.3"
2323
indexmap = "1"
2424
ipnet = "2.9.0"
25+
http = "1.0.0"
2526
outbound-http = { path = "../outbound-http" }
2627
outbound-redis = { path = "../outbound-redis" }
2728
outbound-mqtt = { path = "../outbound-mqtt" }

crates/trigger/src/lib.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::{collections::HashMap, marker::PhantomData};
88

99
use anyhow::{Context, Result};
1010
pub use async_trait::async_trait;
11+
use http::uri::Authority;
1112
use runtime_config::llm::LLmOptions;
1213
use serde::de::DeserializeOwned;
1314

@@ -291,8 +292,8 @@ pub struct TriggerAppEngine<Executor: TriggerExecutor> {
291292
component_instance_pres: HashMap<String, Executor::InstancePre>,
292293
// Resolver for value template expressions
293294
resolver: std::sync::Arc<spin_expressions::PreparedResolver>,
294-
// Map of { Component ID -> Map of { Host -> ParsedClientTlsOpts} }
295-
client_tls_opts: HashMap<String, HashMap<String, ParsedClientTlsOpts>>,
295+
// Map of { Component ID -> Map of { Authority -> ParsedClientTlsOpts} }
296+
client_tls_opts: HashMap<String, HashMap<Authority, ParsedClientTlsOpts>>,
296297
}
297298

298299
impl<Executor: TriggerExecutor> TriggerAppEngine<Executor> {
@@ -304,7 +305,7 @@ impl<Executor: TriggerExecutor> TriggerAppEngine<Executor> {
304305
app: OwnedApp,
305306
hooks: Vec<Box<dyn TriggerHooks>>,
306307
resolver: &std::sync::Arc<spin_expressions::PreparedResolver>,
307-
client_tls_opts: HashMap<String, HashMap<String, ParsedClientTlsOpts>>,
308+
client_tls_opts: HashMap<String, HashMap<Authority, ParsedClientTlsOpts>>,
308309
) -> Result<Self>
309310
where
310311
<Executor as TriggerExecutor>::TriggerConfig: DeserializeOwned,
@@ -440,7 +441,7 @@ impl<Executor: TriggerExecutor> TriggerAppEngine<Executor> {
440441
pub fn get_client_tls_opts(
441442
&self,
442443
component_id: &str,
443-
) -> Option<HashMap<String, ParsedClientTlsOpts>> {
444+
) -> Option<HashMap<Authority, ParsedClientTlsOpts>> {
444445
self.client_tls_opts.get(component_id).cloned()
445446
}
446447

crates/trigger/src/runtime_config.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use std::{
1212
};
1313

1414
use anyhow::{Context, Result};
15+
use http::uri::Authority;
1516
use serde::Deserialize;
1617
use spin_common::ui::quoted_path;
1718
use spin_sqlite::Connection;
@@ -185,16 +186,18 @@ impl RuntimeConfig {
185186

186187
// returns the client tls options in form of nested
187188
// HashMap of { Component ID -> HashMap of { Host -> ParsedClientTlsOpts} }
188-
pub fn client_tls_opts(&self) -> Result<HashMap<String, HashMap<String, ParsedClientTlsOpts>>> {
189-
let mut components_map: HashMap<String, HashMap<String, ParsedClientTlsOpts>> =
189+
pub fn client_tls_opts(
190+
&self,
191+
) -> Result<HashMap<String, HashMap<Authority, ParsedClientTlsOpts>>> {
192+
let mut components_map: HashMap<String, HashMap<Authority, ParsedClientTlsOpts>> =
190193
HashMap::new();
191194

192195
for opt_layer in self.opts_layers() {
193196
for opts in &opt_layer.client_tls_opts {
194197
let parsed = parse_client_tls_opts(opts).context("parsing client tls options")?;
195198
for component_id in &opts.component_ids {
196-
let mut hostmap: HashMap<String, ParsedClientTlsOpts> = HashMap::new();
197-
for host in &opts.hosts {
199+
let mut hostmap: HashMap<Authority, ParsedClientTlsOpts> = HashMap::new();
200+
for host in &parsed.hosts {
198201
hostmap.insert(host.to_owned(), parsed.clone());
199202
}
200203

@@ -547,7 +550,7 @@ mod tests {
547550
#[derive(Debug, Clone)]
548551
pub struct ParsedClientTlsOpts {
549552
pub components: Vec<String>,
550-
pub hosts: Vec<String>,
553+
pub hosts: Vec<Authority>,
551554
pub custom_root_ca: Option<Vec<rustls_pki_types::CertificateDer<'static>>>,
552555
pub cert_chain: Option<Vec<rustls_pki_types::CertificateDer<'static>>>,
553556
pub private_key: Option<Arc<rustls_pki_types::PrivateKeyDer<'static>>>,
@@ -572,8 +575,18 @@ fn parse_client_tls_opts(inp: &ClientTlsOpts) -> Result<ParsedClientTlsOpts, any
572575
None => None,
573576
};
574577

578+
let parsed_hosts: Vec<Authority> = inp
579+
.hosts
580+
.clone()
581+
.into_iter()
582+
.map(|s| {
583+
s.parse::<Authority>()
584+
.map_err(|e| anyhow::anyhow!("failed to parse uri {:?}", e))
585+
})
586+
.collect::<Result<Vec<Authority>, anyhow::Error>>()?;
587+
575588
Ok(ParsedClientTlsOpts {
576-
hosts: inp.hosts.clone(),
589+
hosts: parsed_hosts,
577590
components: inp.component_ids.clone(),
578591
custom_root_ca,
579592
cert_chain,

crates/trigger/src/runtime_config/client_tls.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ use anyhow::Context;
22
use rustls_pemfile::private_key;
33
use std::io;
44
use std::io::Cursor;
5-
use std::{fs, path::{Path, PathBuf}};
5+
use std::{
6+
fs,
7+
path::{Path, PathBuf},
8+
};
69

710
#[derive(Debug, serde::Deserialize)]
811
#[serde(rename_all = "snake_case", tag = "type")]
@@ -36,4 +39,4 @@ pub fn load_key(
3639
))
3740
.map_err(|_| anyhow::anyhow!("invalid input"))
3841
.map(|keys| keys.unwrap())
39-
}
42+
}

0 commit comments

Comments
 (0)