Skip to content

Commit 871918b

Browse files
authored
Merge pull request #2722 from karthik2804/mqtt_factors
Add outbound MQTT factor
2 parents 69c3cd8 + 398cd3f commit 871918b

File tree

5 files changed

+435
-0
lines changed

5 files changed

+435
-0
lines changed

Cargo.lock

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
[package]
2+
name = "spin-factor-outbound-mqtt"
3+
version = { workspace = true }
4+
authors = { workspace = true }
5+
edition = { workspace = true }
6+
7+
[dependencies]
8+
anyhow = "1.0"
9+
rumqttc = { version = "0.24", features = ["url"] }
10+
spin-factor-outbound-networking = { path = "../factor-outbound-networking" }
11+
spin-factors = { path = "../factors" }
12+
spin-core = { path = "../core" }
13+
spin-world = { path = "../world" }
14+
table = { path = "../table" }
15+
tokio = { version = "1.0", features = ["sync"] }
16+
tracing = { workspace = true }
17+
18+
[dev-dependencies]
19+
spin-factor-variables = { path = "../factor-variables" }
20+
spin-factors-test = { path = "../factors-test" }
21+
tokio = { version = "1", features = ["macros", "rt"] }
22+
23+
[lints]
24+
workspace = true
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
use std::{sync::Arc, time::Duration};
2+
3+
use anyhow::Result;
4+
use spin_core::{async_trait, wasmtime::component::Resource};
5+
use spin_factor_outbound_networking::OutboundAllowedHosts;
6+
use spin_world::v2::mqtt::{self as v2, Connection, Error, Qos};
7+
use tracing::{instrument, Level};
8+
9+
#[async_trait]
10+
pub trait ClientCreator: Send + Sync {
11+
fn create(
12+
&self,
13+
address: String,
14+
username: String,
15+
password: String,
16+
keep_alive_interval: Duration,
17+
) -> Result<Arc<dyn MqttClient>, Error>;
18+
}
19+
20+
pub struct InstanceState {
21+
allowed_hosts: OutboundAllowedHosts,
22+
connections: table::Table<Arc<dyn MqttClient>>,
23+
create_client: Arc<dyn ClientCreator>,
24+
}
25+
26+
impl InstanceState {
27+
pub fn new(allowed_hosts: OutboundAllowedHosts, create_client: Arc<dyn ClientCreator>) -> Self {
28+
Self {
29+
allowed_hosts,
30+
create_client,
31+
connections: table::Table::new(1024),
32+
}
33+
}
34+
}
35+
36+
#[async_trait]
37+
pub trait MqttClient: Send + Sync {
38+
async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec<u8>) -> Result<(), Error>;
39+
}
40+
41+
impl InstanceState {
42+
async fn is_address_allowed(&self, address: &str) -> Result<bool> {
43+
self.allowed_hosts.check_url(address, "mqtt").await
44+
}
45+
46+
async fn establish_connection(
47+
&mut self,
48+
address: String,
49+
username: String,
50+
password: String,
51+
keep_alive_interval: Duration,
52+
) -> Result<Resource<Connection>, Error> {
53+
self.connections
54+
.push((self.create_client).create(address, username, password, keep_alive_interval)?)
55+
.map(Resource::new_own)
56+
.map_err(|_| Error::TooManyConnections)
57+
}
58+
59+
async fn get_conn(&self, connection: Resource<Connection>) -> Result<&dyn MqttClient, Error> {
60+
self.connections
61+
.get(connection.rep())
62+
.ok_or(Error::Other(
63+
"could not find connection for resource".into(),
64+
))
65+
.map(|c| c.as_ref())
66+
}
67+
}
68+
69+
impl v2::Host for InstanceState {
70+
fn convert_error(&mut self, error: Error) -> Result<Error> {
71+
Ok(error)
72+
}
73+
}
74+
75+
#[async_trait]
76+
impl v2::HostConnection for InstanceState {
77+
#[instrument(name = "spin_outbound_mqtt.open_connection", skip(self, password), err(level = Level::INFO), fields(otel.kind = "client"))]
78+
async fn open(
79+
&mut self,
80+
address: String,
81+
username: String,
82+
password: String,
83+
keep_alive_interval: u64,
84+
) -> Result<Resource<Connection>, Error> {
85+
if !self
86+
.is_address_allowed(&address)
87+
.await
88+
.map_err(|e| v2::Error::Other(e.to_string()))?
89+
{
90+
return Err(v2::Error::ConnectionFailed(format!(
91+
"address {address} is not permitted"
92+
)));
93+
}
94+
self.establish_connection(
95+
address,
96+
username,
97+
password,
98+
Duration::from_secs(keep_alive_interval),
99+
)
100+
.await
101+
}
102+
103+
/// Publish a message to the MQTT broker.
104+
///
105+
/// OTEL trace propagation is not directly supported in MQTT V3. You will need to embed the
106+
/// current trace context into the payload yourself.
107+
/// https://w3c.github.io/trace-context-mqtt/#mqtt-v3-recommendation.
108+
#[instrument(name = "spin_outbound_mqtt.publish", skip(self, connection, payload), err(level = Level::INFO),
109+
fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish",
110+
messaging.system = "mqtt"))]
111+
async fn publish(
112+
&mut self,
113+
connection: Resource<Connection>,
114+
topic: String,
115+
payload: Vec<u8>,
116+
qos: Qos,
117+
) -> Result<(), Error> {
118+
let conn = self.get_conn(connection).await.map_err(other_error)?;
119+
120+
conn.publish_bytes(topic, qos, payload).await?;
121+
122+
Ok(())
123+
}
124+
125+
fn drop(&mut self, connection: Resource<Connection>) -> anyhow::Result<()> {
126+
self.connections.remove(connection.rep());
127+
Ok(())
128+
}
129+
}
130+
131+
pub fn other_error(e: impl std::fmt::Display) -> Error {
132+
Error::Other(e.to_string())
133+
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
mod host;
2+
3+
use std::sync::Arc;
4+
use std::time::Duration;
5+
6+
use host::other_error;
7+
use host::InstanceState;
8+
use rumqttc::{AsyncClient, Event, Incoming, Outgoing, QoS};
9+
use spin_core::async_trait;
10+
use spin_factor_outbound_networking::OutboundNetworkingFactor;
11+
use spin_factors::{
12+
ConfigureAppContext, Factor, InstanceBuilders, PrepareContext, RuntimeFactors,
13+
SelfInstanceBuilder,
14+
};
15+
use spin_world::v2::mqtt::{self as v2, Error, Qos};
16+
use tokio::sync::Mutex;
17+
18+
pub use host::{ClientCreator, MqttClient};
19+
20+
pub struct OutboundMqttFactor {
21+
create_client: Arc<dyn ClientCreator>,
22+
}
23+
24+
impl OutboundMqttFactor {
25+
pub fn new(create_client: Arc<dyn ClientCreator>) -> Self {
26+
Self { create_client }
27+
}
28+
}
29+
30+
impl Factor for OutboundMqttFactor {
31+
type RuntimeConfig = ();
32+
type AppState = ();
33+
type InstanceBuilder = InstanceState;
34+
35+
fn init<T: Send + 'static>(
36+
&mut self,
37+
mut ctx: spin_factors::InitContext<T, Self>,
38+
) -> anyhow::Result<()> {
39+
ctx.link_bindings(spin_world::v2::mqtt::add_to_linker)?;
40+
Ok(())
41+
}
42+
43+
fn configure_app<T: RuntimeFactors>(
44+
&self,
45+
_ctx: ConfigureAppContext<T, Self>,
46+
) -> anyhow::Result<Self::AppState> {
47+
Ok(())
48+
}
49+
50+
fn prepare<T: RuntimeFactors>(
51+
&self,
52+
_ctx: PrepareContext<Self>,
53+
builders: &mut InstanceBuilders<T>,
54+
) -> anyhow::Result<Self::InstanceBuilder> {
55+
let allowed_hosts = builders
56+
.get_mut::<OutboundNetworkingFactor>()?
57+
.allowed_hosts();
58+
Ok(InstanceState::new(
59+
allowed_hosts,
60+
self.create_client.clone(),
61+
))
62+
}
63+
}
64+
65+
impl SelfInstanceBuilder for InstanceState {}
66+
67+
// This is a concrete implementation of the MQTT client using rumqttc.
68+
pub struct NetworkedMqttClient {
69+
inner: rumqttc::AsyncClient,
70+
event_loop: Mutex<rumqttc::EventLoop>,
71+
}
72+
73+
const MQTT_CHANNEL_CAP: usize = 1000;
74+
75+
impl NetworkedMqttClient {
76+
pub fn create(
77+
address: String,
78+
username: String,
79+
password: String,
80+
keep_alive_interval: Duration,
81+
) -> Result<Self, Error> {
82+
let mut conn_opts = rumqttc::MqttOptions::parse_url(address).map_err(|e| {
83+
tracing::error!("MQTT URL parse error: {e:?}");
84+
Error::InvalidAddress
85+
})?;
86+
conn_opts.set_credentials(username, password);
87+
conn_opts.set_keep_alive(keep_alive_interval);
88+
let (client, event_loop) = AsyncClient::new(conn_opts, MQTT_CHANNEL_CAP);
89+
Ok(Self {
90+
inner: client,
91+
event_loop: Mutex::new(event_loop),
92+
})
93+
}
94+
}
95+
96+
#[async_trait]
97+
impl MqttClient for NetworkedMqttClient {
98+
async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec<u8>) -> Result<(), Error> {
99+
let qos = match qos {
100+
Qos::AtMostOnce => rumqttc::QoS::AtMostOnce,
101+
Qos::AtLeastOnce => rumqttc::QoS::AtLeastOnce,
102+
Qos::ExactlyOnce => rumqttc::QoS::ExactlyOnce,
103+
};
104+
// Message published to EventLoop (not MQTT Broker)
105+
self.inner
106+
.publish_bytes(topic, qos, false, payload.into())
107+
.await
108+
.map_err(other_error)?;
109+
110+
// Poll event loop until outgoing publish event is iterated over to send the message to MQTT broker or capture/throw error.
111+
// We may revisit this later to manage long running connections, high throughput use cases and their issues in the connection pool.
112+
let mut lock = self.event_loop.lock().await;
113+
loop {
114+
let event = lock
115+
.poll()
116+
.await
117+
.map_err(|err| v2::Error::ConnectionFailed(err.to_string()))?;
118+
119+
match (qos, event) {
120+
(QoS::AtMostOnce, Event::Outgoing(Outgoing::Publish(_)))
121+
| (QoS::AtLeastOnce, Event::Incoming(Incoming::PubAck(_)))
122+
| (QoS::ExactlyOnce, Event::Incoming(Incoming::PubComp(_))) => break,
123+
124+
(_, _) => continue,
125+
}
126+
}
127+
Ok(())
128+
}
129+
}

0 commit comments

Comments
 (0)