Skip to content

Commit cbb2cb9

Browse files
authored
Merge pull request #2433 from calebschoepp/no-traps-in-host-components
ref(*): Refactor host components to avoid returning Result<Result<T>> if they don't trap
2 parents 9f6afcd + 3398132 commit cbb2cb9

File tree

13 files changed

+888
-854
lines changed

13 files changed

+888
-854
lines changed

Cargo.lock

Lines changed: 399 additions & 370 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,10 @@ hyper = { version = "1.0.0", features = ["full"] }
127127
reqwest = { version = "0.11", features = ["stream", "blocking"] }
128128
tracing = { version = "0.1", features = ["log"] }
129129

130-
wasi-common-preview1 = { version = "18.0.1", package = "wasi-common" }
131-
wasmtime = "18.0.1"
132-
wasmtime-wasi = { version = "18.0.1", features = ["tokio"] }
133-
wasmtime-wasi-http = "18.0.1"
130+
wasi-common-preview1 = { version = "18.0.4", package = "wasi-common" }
131+
wasmtime = "18.0.4"
132+
wasmtime-wasi = { version = "18.0.4", features = ["tokio"] }
133+
wasmtime-wasi-http = "18.0.4"
134134

135135
spin-componentize = { path = "crates/componentize" }
136136

crates/llm/src/lib.rs

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,11 @@ impl v2::Host for LlmDispatch {
3939
model: v2::InferencingModel,
4040
prompt: String,
4141
params: Option<v2::InferencingParams>,
42-
) -> anyhow::Result<Result<v2::InferencingResult, v2::Error>> {
42+
) -> Result<v2::InferencingResult, v2::Error> {
4343
if !self.allowed_models.contains(&model) {
44-
return Ok(Err(access_denied_error(&model)));
44+
return Err(access_denied_error(&model));
4545
}
46-
Ok(self
47-
.engine
46+
self.engine
4847
.infer(
4948
model,
5049
prompt,
@@ -57,18 +56,22 @@ impl v2::Host for LlmDispatch {
5756
top_p: 0.9,
5857
}),
5958
)
60-
.await)
59+
.await
6160
}
6261

6362
async fn generate_embeddings(
6463
&mut self,
6564
m: v1::EmbeddingModel,
6665
data: Vec<String>,
67-
) -> anyhow::Result<Result<v2::EmbeddingsResult, v2::Error>> {
66+
) -> Result<v2::EmbeddingsResult, v2::Error> {
6867
if !self.allowed_models.contains(&m) {
69-
return Ok(Err(access_denied_error(&m)));
68+
return Err(access_denied_error(&m));
7069
}
71-
Ok(self.engine.generate_embeddings(m, data).await)
70+
self.engine.generate_embeddings(m, data).await
71+
}
72+
73+
fn convert_error(&mut self, error: v2::Error) -> anyhow::Result<v2::Error> {
74+
Ok(error)
7275
}
7376
}
7477

@@ -79,24 +82,26 @@ impl v1::Host for LlmDispatch {
7982
model: v1::InferencingModel,
8083
prompt: String,
8184
params: Option<v1::InferencingParams>,
82-
) -> anyhow::Result<Result<v1::InferencingResult, v1::Error>> {
83-
Ok(
84-
<Self as v2::Host>::infer(self, model, prompt, params.map(Into::into))
85-
.await?
86-
.map(Into::into)
87-
.map_err(Into::into),
88-
)
85+
) -> Result<v1::InferencingResult, v1::Error> {
86+
<Self as v2::Host>::infer(self, model, prompt, params.map(Into::into))
87+
.await
88+
.map(Into::into)
89+
.map_err(Into::into)
8990
}
9091

9192
async fn generate_embeddings(
9293
&mut self,
9394
model: v1::EmbeddingModel,
9495
data: Vec<String>,
95-
) -> anyhow::Result<Result<v1::EmbeddingsResult, v1::Error>> {
96-
Ok(<Self as v2::Host>::generate_embeddings(self, model, data)
97-
.await?
96+
) -> Result<v1::EmbeddingsResult, v1::Error> {
97+
<Self as v2::Host>::generate_embeddings(self, model, data)
98+
.await
9899
.map(Into::into)
99-
.map_err(Into::into))
100+
.map_err(Into::into)
101+
}
102+
103+
fn convert_error(&mut self, error: v1::Error) -> anyhow::Result<v1::Error> {
104+
Ok(error)
100105
}
101106
}
102107

crates/outbound-http/src/host_impl.rs

Lines changed: 60 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use spin_core::async_trait;
55
use spin_outbound_networking::{AllowedHostsConfig, OutboundUrl};
66
use spin_world::v1::{
77
http as outbound_http,
8-
http_types::{Headers, HttpError, Method, Request, Response},
8+
http_types::{self, Headers, HttpError, Method, Request, Response},
99
};
1010
use tracing::{field::Empty, instrument, Level};
1111

@@ -42,72 +42,75 @@ impl outbound_http::Host for OutboundHttp {
4242
#[instrument(name = "spin_outbound_http.send_request", skip_all, err(level = Level::INFO),
4343
fields(otel.kind = "client", url.full = Empty, http.request.method = Empty,
4444
http.response.status_code = Empty, otel.name = Empty, server.address = Empty, server.port = Empty))]
45-
async fn send_request(&mut self, req: Request) -> Result<Result<Response, HttpError>> {
46-
Ok(async {
47-
let current_span = tracing::Span::current();
48-
let method = format!("{:?}", req.method)
49-
.strip_prefix("Method::")
50-
.unwrap_or("_OTHER")
51-
.to_uppercase();
52-
current_span.record("otel.name", method.clone());
53-
current_span.record("url.full", req.uri.clone());
54-
current_span.record("http.request.method", method);
55-
if let Ok(uri) = req.uri.parse::<Uri>() {
56-
if let Some(authority) = uri.authority() {
57-
current_span.record("server.address", authority.host());
58-
if let Some(port) = authority.port() {
59-
current_span.record("server.port", port.as_u16());
60-
}
45+
async fn send_request(&mut self, req: Request) -> Result<Response, HttpError> {
46+
let current_span = tracing::Span::current();
47+
let method = format!("{:?}", req.method)
48+
.strip_prefix("Method::")
49+
.unwrap_or("_OTHER")
50+
.to_uppercase();
51+
current_span.record("otel.name", method.clone());
52+
current_span.record("url.full", req.uri.clone());
53+
current_span.record("http.request.method", method);
54+
if let Ok(uri) = req.uri.parse::<Uri>() {
55+
if let Some(authority) = uri.authority() {
56+
current_span.record("server.address", authority.host());
57+
if let Some(port) = authority.port() {
58+
current_span.record("server.port", port.as_u16());
6159
}
6260
}
61+
}
6362

64-
tracing::log::trace!("Attempting to send outbound HTTP request to {}", req.uri);
65-
if !self
66-
.is_allowed(&req.uri)
67-
.map_err(|_| HttpError::RuntimeError)?
68-
{
69-
tracing::log::info!("Destination not allowed: {}", req.uri);
70-
if let Some((scheme, host_and_port)) = scheme_host_and_port(&req.uri) {
71-
terminal::warn!("A component tried to make a HTTP request to non-allowed host '{host_and_port}'.");
72-
eprintln!("To allow requests, add 'allowed_outbound_hosts = [\"{scheme}://{host_and_port}\"]' to the manifest component section.");
73-
}
74-
return Err(HttpError::DestinationNotAllowed);
63+
tracing::log::trace!("Attempting to send outbound HTTP request to {}", req.uri);
64+
if !self
65+
.is_allowed(&req.uri)
66+
.map_err(|_| HttpError::RuntimeError)?
67+
{
68+
tracing::log::info!("Destination not allowed: {}", req.uri);
69+
if let Some((scheme, host_and_port)) = scheme_host_and_port(&req.uri) {
70+
terminal::warn!("A component tried to make a HTTP request to non-allowed host '{host_and_port}'.");
71+
eprintln!("To allow requests, add 'allowed_outbound_hosts = [\"{scheme}://{host_and_port}\"]' to the manifest component section.");
7572
}
73+
return Err(HttpError::DestinationNotAllowed);
74+
}
7675

77-
let method = method_from(req.method);
78-
79-
let abs_url = if req.uri.starts_with('/') {
80-
format!("{}{}", self.origin, req.uri)
81-
} else {
82-
req.uri.clone()
83-
};
76+
let method = method_from(req.method);
8477

85-
let req_url = reqwest::Url::parse(&abs_url).map_err(|_| HttpError::InvalidUrl)?;
78+
let abs_url = if req.uri.starts_with('/') {
79+
format!("{}{}", self.origin, req.uri)
80+
} else {
81+
req.uri.clone()
82+
};
8683

87-
let mut headers = request_headers(req.headers).map_err(|_| HttpError::RuntimeError)?;
88-
spin_telemetry::inject_trace_context(&mut headers);
89-
let body = req.body.unwrap_or_default().to_vec();
84+
let req_url = reqwest::Url::parse(&abs_url).map_err(|_| HttpError::InvalidUrl)?;
9085

91-
if !req.params.is_empty() {
92-
tracing::log::warn!("HTTP params field is deprecated");
93-
}
86+
let mut headers = request_headers(req.headers).map_err(|_| HttpError::RuntimeError)?;
87+
spin_telemetry::inject_trace_context(&mut headers);
88+
let body = req.body.unwrap_or_default().to_vec();
9489

95-
// Allow reuse of Client's internal connection pool for multiple requests
96-
// in a single component execution
97-
let client = self.client.get_or_insert_with(Default::default);
98-
99-
let resp = client
100-
.request(method, req_url)
101-
.headers(headers)
102-
.body(body)
103-
.send()
104-
.await
105-
.map_err(log_reqwest_error)?;
106-
tracing::log::trace!("Returning response from outbound request to {}", req.uri);
107-
current_span.record("http.response.status_code", resp.status().as_u16());
108-
response_from_reqwest(resp).await
90+
if !req.params.is_empty() {
91+
tracing::log::warn!("HTTP params field is deprecated");
10992
}
110-
.await)
93+
94+
// Allow reuse of Client's internal connection pool for multiple requests
95+
// in a single component execution
96+
let client = self.client.get_or_insert_with(Default::default);
97+
98+
let resp = client
99+
.request(method, req_url)
100+
.headers(headers)
101+
.body(body)
102+
.send()
103+
.await
104+
.map_err(log_reqwest_error)?;
105+
tracing::log::trace!("Returning response from outbound request to {}", req.uri);
106+
current_span.record("http.response.status_code", resp.status().as_u16());
107+
response_from_reqwest(resp).await
108+
}
109+
}
110+
111+
impl http_types::Host for OutboundHttp {
112+
fn convert_http_error(&mut self, error: HttpError) -> Result<HttpError> {
113+
Ok(error)
111114
}
112115
}
113116

crates/outbound-mqtt/src/lib.rs

Lines changed: 45 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,27 @@ impl OutboundMqtt {
3737
username: String,
3838
password: String,
3939
keep_alive_interval: Duration,
40-
) -> Result<Result<Resource<MqttConnection>, Error>> {
41-
Ok(async {
42-
let mut conn_opts = rumqttc::MqttOptions::parse_url(address).map_err(|e| {
43-
tracing::error!("MQTT URL parse error: {e:?}");
44-
Error::InvalidAddress
45-
})?;
46-
conn_opts.set_credentials(username, password);
47-
conn_opts.set_keep_alive(keep_alive_interval);
48-
let (client, event_loop) = AsyncClient::new(conn_opts, MQTT_CHANNEL_CAP);
49-
50-
self.connections
51-
.push((client, event_loop))
52-
.map(Resource::new_own)
53-
.map_err(|_| Error::TooManyConnections)
54-
}
55-
.await)
40+
) -> Result<Resource<MqttConnection>, Error> {
41+
let mut conn_opts = rumqttc::MqttOptions::parse_url(address).map_err(|e| {
42+
tracing::error!("MQTT URL parse error: {e:?}");
43+
Error::InvalidAddress
44+
})?;
45+
conn_opts.set_credentials(username, password);
46+
conn_opts.set_keep_alive(keep_alive_interval);
47+
let (client, event_loop) = AsyncClient::new(conn_opts, MQTT_CHANNEL_CAP);
48+
49+
self.connections
50+
.push((client, event_loop))
51+
.map(Resource::new_own)
52+
.map_err(|_| Error::TooManyConnections)
5653
}
5754
}
5855

59-
impl v2::Host for OutboundMqtt {}
56+
impl v2::Host for OutboundMqtt {
57+
fn convert_error(&mut self, error: Error) -> Result<Error> {
58+
Ok(error)
59+
}
60+
}
6061

6162
#[async_trait]
6263
impl v2::HostConnection for OutboundMqtt {
@@ -67,11 +68,11 @@ impl v2::HostConnection for OutboundMqtt {
6768
username: String,
6869
password: String,
6970
keep_alive_interval: u64,
70-
) -> Result<Result<Resource<MqttConnection>, Error>> {
71+
) -> Result<Resource<MqttConnection>, Error> {
7172
if !self.is_address_allowed(&address) {
72-
return Ok(Err(v2::Error::ConnectionFailed(format!(
73+
return Err(v2::Error::ConnectionFailed(format!(
7374
"address {address} is not permitted"
74-
))));
75+
)));
7576
}
7677
self.establish_connection(
7778
address,
@@ -96,36 +97,33 @@ impl v2::HostConnection for OutboundMqtt {
9697
topic: String,
9798
payload: Vec<u8>,
9899
qos: Qos,
99-
) -> Result<Result<(), Error>> {
100-
Ok(async {
101-
let (client, eventloop) = self.get_conn(connection).await.map_err(other_error)?;
102-
let qos = convert_to_mqtt_qos_value(qos);
103-
104-
// Message published to EventLoop (not MQTT Broker)
105-
client
106-
.publish_bytes(topic, qos, false, payload.into())
100+
) -> Result<(), Error> {
101+
let (client, eventloop) = self.get_conn(connection).await.map_err(other_error)?;
102+
let qos = convert_to_mqtt_qos_value(qos);
103+
104+
// Message published to EventLoop (not MQTT Broker)
105+
client
106+
.publish_bytes(topic, qos, false, payload.into())
107+
.await
108+
.map_err(other_error)?;
109+
110+
// Poll event loop until outgoing publish event is iterated over to send the message to MQTT broker or capture/throw error.
111+
// We may revisit this later to manage long running connections, high throughput use cases and their issues in the connection pool.
112+
loop {
113+
let event = eventloop
114+
.poll()
107115
.await
108-
.map_err(other_error)?;
109-
110-
// Poll event loop until outgoing publish event is iterated over to send the message to MQTT broker or capture/throw error.
111-
// We may revisit this later to manage long running connections, high throughput use cases and their issues in the connection pool.
112-
loop {
113-
let event = eventloop
114-
.poll()
115-
.await
116-
.map_err(|err| v2::Error::ConnectionFailed(err.to_string()))?;
117-
118-
match (qos, event) {
119-
(QoS::AtMostOnce, Event::Outgoing(Outgoing::Publish(_)))
120-
| (QoS::AtLeastOnce, Event::Incoming(Incoming::PubAck(_)))
121-
| (QoS::ExactlyOnce, Event::Incoming(Incoming::PubComp(_))) => break,
122-
123-
(_, _) => continue,
124-
}
116+
.map_err(|err| v2::Error::ConnectionFailed(err.to_string()))?;
117+
118+
match (qos, event) {
119+
(QoS::AtMostOnce, Event::Outgoing(Outgoing::Publish(_)))
120+
| (QoS::AtLeastOnce, Event::Incoming(Incoming::PubAck(_)))
121+
| (QoS::ExactlyOnce, Event::Incoming(Incoming::PubComp(_))) => break,
122+
123+
(_, _) => continue,
125124
}
126-
Ok(())
127125
}
128-
.await)
126+
Ok(())
129127
}
130128

131129
fn drop(&mut self, connection: Resource<MqttConnection>) -> anyhow::Result<()> {

0 commit comments

Comments
 (0)